diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/functions/AsyncTableFunction.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/functions/AsyncTableFunction.java new file mode 100644 index 0000000000000..0f609a56945e5 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/functions/AsyncTableFunction.java @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.functions; + +import org.apache.flink.api.common.functions.InvalidTypesException; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.TypeExtractor; +import org.apache.flink.streaming.api.functions.async.ResultFuture; +import org.apache.flink.table.api.ValidationException; + +/** + * Base class for a user-defined asynchronously table function (UDTF). This is similar to + * {@link TableFunction} but this function is asynchronously. + * + *

A user-defined table function works on + * zero, one, or multiple scalar values as input and returns multiple rows as output. + * + *

The behavior of a {@link AsyncTableFunction} can be defined by implementing a custom evaluation + * method. An evaluation method must be declared publicly, not static and named "eval". + * Evaluation methods can also be overloaded by implementing multiple methods named "eval". + * + *

The first parameter of the evaluation method must be {@link ResultFuture}, and the others are user + * defined input parameters, just like in the "eval" method of {@link TableFunction}. + * + *

For each "eval", an async io operation can be triggered, and once it has been done, + * the result can be collected by calling {@link ResultFuture#complete}. For each async + * operation, its context is stored in the operator immediately after invoking "eval", + * avoiding blocking for each stream input as long as the internal buffer is not full. + * + *

{@link ResultFuture} can be passed into callbacks or futures to collect the result data. + * An error can also be propagated to the async I/O operator via + * {@link ResultFuture#completeExceptionally(Throwable)}. + * + *

User-defined functions must have a default constructor and must be instantiable during + * runtime. + * + *

By default the result type of an evaluation method is determined by Flink's type extraction + * facilities. This is sufficient for basic types or simple POJOs but might be wrong for more + * complex, custom, or composite types. In these cases {@link TypeInformation} of the result type + * can be manually defined by overriding {@link #getResultType}. + * + *
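For instance, a minimal sketch of such an override, assuming a function that emits Row results (the class name, lookup key, and row layout below are illustrative, not part of this patch), could look like:

    import org.apache.flink.api.common.typeinfo.TypeInformation;
    import org.apache.flink.api.common.typeinfo.Types;
    import org.apache.flink.streaming.api.functions.async.ResultFuture;
    import org.apache.flink.types.Row;

    public class OrderAsyncLookupFunction extends AsyncTableFunction<Row> {

        public void eval(ResultFuture<Row> resultFuture, String orderId) {
            // trigger the asynchronous lookup here and eventually call
            // resultFuture.complete(...) or resultFuture.completeExceptionally(...)
        }

        @Override
        public TypeInformation<Row> getResultType() {
            // declare the produced row type explicitly instead of relying on type extraction
            return Types.ROW(Types.STRING, Types.LONG);
        }
    }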

Internally, the Table/SQL API code generation works with primitive values as much as possible. + * If a user-defined table function should not introduce much overhead during runtime, it is + * recommended to declare parameters and result types as primitive types instead of their boxed + * classes. DATE/TIME is equal to int, TIMESTAMP is equal to long. + * + *

Example: + * + * {@code + * + * public class HBaseAsyncTableFunction extends AsyncTableFunction<String> { + * + * // implement an "eval" method with as many parameters as you want + * public void eval(ResultFuture<String> resultFuture, String rowkey) { + * Get get = new Get(Bytes.toBytes(rowkey)); + * ListenableFuture<Result> future = hbase.asyncGet(get); + * Futures.addCallback(future, new FutureCallback<Result>() { + * public void onSuccess(Result result) { + * List<String> ret = process(result); + * resultFuture.complete(ret); + * } + * public void onFailure(Throwable thrown) { + * resultFuture.completeExceptionally(thrown); + * } + * }); + * } + * + * // you can overload the eval method here ... + * } + * } + * + *

NOTE: the {@link AsyncTableFunction} can not be used as a UDTF currently. It is only used in + * temporal table joins as an async lookup function. + * + * @param <T> The type of the output row + */ +public abstract class AsyncTableFunction<T> extends UserDefinedFunction { + + /** + * Returns the result type of the evaluation method with a given signature. + * + *

This method needs to be overridden in case Flink's type extraction facilities are not + * sufficient to extract the {@link TypeInformation} based on the return type of the evaluation + * method. Flink's type extraction facilities can handle basic types or + * simple POJOs but might be wrong for more complex, custom, or composite types. + * + * @return {@link TypeInformation} of result type or null if Flink should determine the type + */ + public TypeInformation getResultType() { + return null; + } + + /** + * Returns {@link TypeInformation} about the operands of the evaluation method with a given + * signature. + * + *

In order to perform operand type inference in SQL (especially when NULL is used) it might be + * necessary to determine the parameter {@link TypeInformation} of an evaluation method. + * By default Flink's type extraction facilities are used for this but might be wrong for + * more complex, custom, or composite types. + * + * @param signature signature of the method the operand types need to be determined + * @return {@link TypeInformation} of operand types + */ + public TypeInformation[] getParameterTypes(Class[] signature) { + final TypeInformation[] types = new TypeInformation[signature.length]; + for (int i = 0; i < signature.length; i++) { + try { + types[i] = TypeExtractor.getForClass(signature[i]); + } catch (InvalidTypesException e) { + throw new ValidationException( + "Parameter types of table function " + this.getClass().getCanonicalName() + + " cannot be automatically determined. Please provide type information manually."); + } + } + return types; + } +} diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/functions/sql/ProctimeSqlFunction.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/functions/sql/ProctimeSqlFunction.java index 78b0ed2491f9c..b88439096b735 100644 --- a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/functions/sql/ProctimeSqlFunction.java +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/functions/sql/ProctimeSqlFunction.java @@ -46,7 +46,7 @@ public ProctimeSqlFunction() { private static class ProctimeRelProtoDataType implements RelProtoDataType { @Override public RelDataType apply(RelDataTypeFactory factory) { - return ((FlinkTypeFactory) factory).createRowtimeIndicatorType(); + return ((FlinkTypeFactory) factory).createProctimeIndicatorType(); } } } diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/rules/logical/FlinkJoinPushExpressionsRule.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/rules/logical/FlinkJoinPushExpressionsRule.java index 7e6b9601afdaa..07335e7c83caa 100644 --- a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/rules/logical/FlinkJoinPushExpressionsRule.java +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/rules/logical/FlinkJoinPushExpressionsRule.java @@ -32,6 +32,7 @@ * This rules is copied from Calcite's {@link org.apache.calcite.rel.rules.JoinPushExpressionsRule}. * Modification: * - Supports SEMI/ANTI join using {@link org.apache.flink.table.plan.util.FlinkRelOptUtil#pushDownJoinConditions} + * - Only push down calls on non-time-indicator field. */ /** diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/sources/DefinedIndexes.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/sources/DefinedIndexes.java new file mode 100644 index 0000000000000..38addccb900da --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/sources/DefinedIndexes.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.sources; + +import java.util.Collection; + +/** + * The {@link DefinedIndexes} interface can extend a {@link TableSource} to specify the + * index meta information. + * + *

An Index can be a Unique Index or a Normal Index. A Unique Index is similar to a primary + * key, which defines a column or a group of columns that uniquely identifies each row in + * a table or stream. A Normal Index is an index on the defined columns used to accelerate + * querying. + */ +public interface DefinedIndexes { + + /** + * Returns the list of {@link TableIndex}s. Returns an empty collection or null if no + * index exists. + */ + Collection<TableIndex> getIndexes(); + +} diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/sources/DefinedPrimaryKey.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/sources/DefinedPrimaryKey.java new file mode 100644 index 0000000000000..8794efd751852 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/sources/DefinedPrimaryKey.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.sources; + +import javax.annotation.Nullable; + +import java.util.List; + +/** + * The {@link DefinedPrimaryKey} interface can extend a {@link TableSource} to specify the + * primary key meta information. + * + *

A primary key is a column or a group of columns that uniquely identifies each row in + * a table or stream. + * + *
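To make the two interfaces concrete, a hypothetical table source (the class name and column names are illustrative; TableIndex and its builder are introduced later in this patch) could expose its key and index metadata like this:

    import java.util.Collection;
    import java.util.Collections;
    import java.util.List;

    import org.apache.flink.types.Row;

    public abstract class OrdersTableSource
            implements TableSource<Row>, DefinedPrimaryKey, DefinedIndexes {

        @Override
        public List<String> getPrimaryKeyColumns() {
            // "order_id" uniquely identifies each row; no extra unique index is declared for it
            return Collections.singletonList("order_id");
        }

        @Override
        public Collection<TableIndex> getIndexes() {
            // a normal (non-unique) index to accelerate lookups on "user_id"
            return Collections.singletonList(
                TableIndex.builder()
                    .normalIndex()
                    .indexedColumns("user_id")
                    .name("idx_user_id")
                    .build());
        }

        // the remaining TableSource methods (getTableSchema, getReturnType, ...) are omitted
    }

As the NOTE below explains, the primary key itself does not need to be repeated as a unique index.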

NOTE: Although a primary key usually has an Unique Index, if you have defined + * a primary key, there is no need to define a same index in {@link DefinedIndexes} again. + */ +public interface DefinedPrimaryKey { + + /** + * Returns the column names of the primary key. Returns null if no primary key existed + * in the {@link TableSource}. + */ + @Nullable + List getPrimaryKeyColumns(); + +} diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/sources/LookupConfig.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/sources/LookupConfig.java new file mode 100644 index 0000000000000..adb84a31fcce9 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/sources/LookupConfig.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.sources; + +/** + * The {@link LookupConfig} is used to configure some behavior when lookup a table. + * + * @see LookupableTableSource#getLookupConfig() + */ +public class LookupConfig { + + public static final LookupConfig DEFAULT = LookupConfig.builder().build(); + + private static final boolean DEFAULT_ASYNC_ENABLED = false; + private static final long DEFAULT_ASYNC_TIMEOUT_MS = 180_000; + private static final int DEFAULT_ASYNC_BUFFER_CAPACITY = 100; + + private final boolean asyncEnabled; + private final long asyncTimeoutMs; + private final int asyncBufferCapacity; + + private LookupConfig(boolean asyncEnabled, long asyncTimeoutMs, int asyncBufferCapacity) { + this.asyncEnabled = asyncEnabled; + this.asyncTimeoutMs = asyncTimeoutMs; + this.asyncBufferCapacity = asyncBufferCapacity; + } + + /** + * Returns true if async lookup is enabled. + */ + public boolean isAsyncEnabled() { + return asyncEnabled; + } + + /** + * Returns async timeout millisecond for the asynchronous operation to complete. + */ + public long getAsyncTimeoutMs() { + return asyncTimeoutMs; + } + + /** + * Returns the max number of async i/o operation that can be triggered. + */ + public int getAsyncBufferCapacity() { + return asyncBufferCapacity; + } + + /** + * Returns a new builder that builds a {@link LookupConfig}. + * + *

For example: + * + *

+	 *     LookupConfig.builder()
+	 *       .setAsyncEnabled(true)
+	 *       .setAsyncBufferCapacity(1000)
+	 *       .setAsyncTimeoutMs(30000)
+	 *       .build();
+	 * 
+ */ + public static Builder builder() { + return new Builder(); + } + + /** + * A builder used to build a new {@link LookupConfig}. + */ + public static class Builder { + + private boolean asyncEnabled = DEFAULT_ASYNC_ENABLED; + private long asyncTimeoutMs = DEFAULT_ASYNC_TIMEOUT_MS; + private int asyncBufferCapacity = DEFAULT_ASYNC_BUFFER_CAPACITY; + + public Builder setAsyncEnabled(boolean asyncEnabled) { + this.asyncEnabled = asyncEnabled; + return this; + } + + public Builder setAsyncTimeoutMs(long timeoutMs) { + this.asyncTimeoutMs = timeoutMs; + return this; + } + + public Builder setAsyncBufferCapacity(int bufferCapacity) { + this.asyncBufferCapacity = bufferCapacity; + return this; + } + + public LookupConfig build() { + return new LookupConfig(asyncEnabled, asyncTimeoutMs, asyncBufferCapacity); + } + + } +} diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/sources/LookupableTableSource.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/sources/LookupableTableSource.java new file mode 100644 index 0000000000000..2f38e63b1739c --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/sources/LookupableTableSource.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.sources; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.table.functions.AsyncTableFunction; +import org.apache.flink.table.functions.TableFunction; + +/** + * A {@link TableSource} which supports for lookup accessing via key column(s). + * For example, MySQL TableSource can implement this interface to support lookup accessing. + * When temporal join this MySQL table, the runtime behavior could be in a lookup fashion. + * + * @param type of the result + */ +@PublicEvolving +public interface LookupableTableSource extends TableSource { + + /** + * Gets the {@link TableFunction} which supports lookup one key at a time. + * @param lookupKeys the chosen field names as lookup keys, it is in the defined order + */ + TableFunction getLookupFunction(String[] lookupKeys); + + /** + * Gets the {@link AsyncTableFunction} which supports async lookup one key at a time. + * @param lookupKeys the chosen field names as lookup keys, it is in the defined order + */ + AsyncTableFunction getAsyncLookupFunction(String[] lookupKeys); + + /** + * Defines the lookup behavior in the config. Such as whether to use async lookup. 
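Putting these pieces together, a hypothetical LookupableTableSource implementation (MySqlLookupFunction and MySqlAsyncLookupFunction are made-up user-defined functions, not part of this patch) would wire up its lookup functions and configuration roughly as follows:

    import org.apache.flink.table.functions.AsyncTableFunction;
    import org.apache.flink.table.functions.TableFunction;
    import org.apache.flink.types.Row;

    public abstract class MySqlLookupableTableSource implements LookupableTableSource<Row> {

        @Override
        public TableFunction<Row> getLookupFunction(String[] lookupKeys) {
            // hypothetical synchronous lookup function, e.g. one blocking JDBC query per key
            return new MySqlLookupFunction(lookupKeys);
        }

        @Override
        public AsyncTableFunction<Row> getAsyncLookupFunction(String[] lookupKeys) {
            // hypothetical asynchronous lookup function backed by a non-blocking client
            return new MySqlAsyncLookupFunction(lookupKeys);
        }

        @Override
        public LookupConfig getLookupConfig() {
            // enable async lookup with a 30s timeout and up to 100 in-flight requests
            return LookupConfig.builder()
                .setAsyncEnabled(true)
                .setAsyncTimeoutMs(30_000)
                .setAsyncBufferCapacity(100)
                .build();
        }
    }

With async lookup enabled like this, the planner would translate a temporal table join against the source into the async lookup join operator added by this patch.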
+ */ + LookupConfig getLookupConfig(); +} diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/sources/TableIndex.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/sources/TableIndex.java new file mode 100644 index 0000000000000..1495e2cb15009 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/sources/TableIndex.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.sources; + +import java.util.Arrays; +import java.util.List; +import java.util.Optional; + +import static org.apache.flink.util.Preconditions.checkNotNull; +import static org.apache.flink.util.Preconditions.checkState; + +/** + * An Index meta information of a Table. + */ +public class TableIndex { + + /** + * Index type, currently only support NORMAL INDEX, and UNIQUE INDEX. + */ + public enum IndexType { + NORMAL, + UNIQUE + } + + private final String indexName; + private final IndexType indexType; + private final List indexedColumns; + private final String indexComment; + + private TableIndex(String indexName, IndexType indexType, List indexedColumns, String indexComment) { + this.indexName = indexName; + this.indexType = indexType; + this.indexedColumns = indexedColumns; + this.indexComment = indexComment; + } + + /** + * Returns name of the Index. + * + * @return an optional name of the index. + */ + public Optional getIndexName() { + return Optional.ofNullable(indexName); + } + + /** + * Returns the column names of the index. + */ + public List getIndexedColumns() { + return indexedColumns; + } + + /** + * Returns the type of the index. + */ + public IndexType getIndexType() { + return indexType; + } + + /** + * Returns comment of the index. + * @return an optional comment of the index. + */ + public Optional getIndexComment() { + return Optional.ofNullable(indexComment); + } + + /** + * Returns a new builder that builds a {@link TableIndex}. + * + *

For example: + *

+	 *     TableIndex.builder()
+	 *       .uniqueIndex()
+	 *       .indexedColumns("user_id", "user_name")
+	 *       .name("idx_1")
+	 *       .build();
+	 * 
+ */ + public static Builder builder() { + return new Builder(); + } + + /** + * A builder used to construct a {@link TableIndex}. + */ + public static class Builder { + private String indexName; + private IndexType indexType; + private List indexedColumns; + private String indexComment; + + public Builder normalIndex() { + checkState(indexType == null, "IndexType has been set."); + this.indexType = IndexType.NORMAL; + return this; + } + + public Builder uniqueIndex() { + checkState(indexType == null, "IndexType has been set."); + this.indexType = IndexType.UNIQUE; + return this; + } + + public Builder name(String name) { + this.indexName = name; + return this; + } + + public Builder indexedColumns(List indexedColumns) { + this.indexedColumns = indexedColumns; + return this; + } + + public Builder indexedColumns(String... indexedColumns) { + this.indexedColumns = Arrays.asList(indexedColumns); + return this; + } + + public Builder comment(String comment) { + this.indexComment = comment; + return this; + } + + public TableIndex build() { + checkNotNull(indexedColumns); + checkNotNull(indexType); + return new TableIndex(indexName, indexType, indexedColumns, indexComment); + } + + } + +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/BatchTableEnvironment.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/BatchTableEnvironment.scala index 2b38b31f5587c..3668ecbb3c1f2 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/BatchTableEnvironment.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/BatchTableEnvironment.scala @@ -167,6 +167,10 @@ class BatchTableEnvironment( def explain(table: Table, extended: Boolean): String = { val ast = table.asInstanceOf[TableImpl].getRelNode val optimizedNode = optimize(ast) + // translate plan to physical operators + val optimizedNodes = translateNodeDag(Seq(optimizedNode)) + require(optimizedNodes.size() == 1) + translateToPlan(optimizedNodes.head) val explainLevel = if (extended) { SqlExplainLevel.ALL_ATTRIBUTES diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/StreamTableEnvironment.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/StreamTableEnvironment.scala index 48442653d33a9..0f9d68c8bbfc4 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/StreamTableEnvironment.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/StreamTableEnvironment.scala @@ -215,6 +215,10 @@ abstract class StreamTableEnvironment( def explain(table: Table, extended: Boolean): String = { val ast = table.asInstanceOf[TableImpl].getRelNode val optimizedNode = optimize(ast) + // translate plan to physical operators + val optimizedNodes = translateNodeDag(Seq(optimizedNode)) + require(optimizedNodes.size() == 1) + translateToPlan(optimizedNodes.head) val explainLevel = if (extended) { SqlExplainLevel.ALL_ATTRIBUTES diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/CalcCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/CalcCodeGenerator.scala index fec7397639620..c93ecbb007fb7 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/CalcCodeGenerator.scala +++ 
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/CalcCodeGenerator.scala @@ -22,9 +22,10 @@ import org.apache.flink.table.`type`.{RowType, TypeConverters} import org.apache.flink.table.api.{TableConfig, TableException} import org.apache.flink.table.dataformat.{BaseRow, BoxedWrapperRow} import org.apache.flink.table.runtime.CodeGenOperatorFactory - import org.apache.calcite.plan.RelOptCluster import org.apache.calcite.rex._ +import org.apache.flink.api.common.functions.{FlatMapFunction, Function} +import org.apache.flink.table.generated.GeneratedFunction import scala.collection.JavaConversions._ @@ -53,6 +54,7 @@ object CalcCodeGenerator { config, calcProgram, condition, + eagerInputUnboxingCode = true, retainHeader = retainHeader) val genOperator = @@ -69,6 +71,42 @@ object CalcCodeGenerator { new CodeGenOperatorFactory(genOperator) } + private[flink] def generateFunction[T <: Function]( + inputType: RowType, + name: String, + returnType: RowType, + outRowClass: Class[_ <: BaseRow], + calcProjection: RexProgram, + calcCondition: Option[RexNode], + config: TableConfig): GeneratedFunction[FlatMapFunction[BaseRow, BaseRow]] = { + val ctx = CodeGeneratorContext(config) + val inputTerm = CodeGenUtils.DEFAULT_INPUT1_TERM + val collectorTerm = CodeGenUtils.DEFAULT_COLLECTOR_TERM + val processCode = generateProcessCode( + ctx, + inputType, + returnType, + outRowClass, + returnType.getFieldNames, + config, + calcProjection, + calcCondition, + collectorTerm = collectorTerm, + eagerInputUnboxingCode = false, + outputDirectly = true + ) + + FunctionCodeGenerator.generateFunction( + ctx, + name, + classOf[FlatMapFunction[BaseRow, BaseRow]], + processCode, + returnType, + inputType, + input1Term = inputTerm, + collectorTerm = collectorTerm) + } + private[flink] def generateProcessCode( ctx: CodeGeneratorContext, inputType: RowType, @@ -80,6 +118,7 @@ object CalcCodeGenerator { condition: Option[RexNode], inputTerm: String = CodeGenUtils.DEFAULT_INPUT1_TERM, collectorTerm: String = CodeGenUtils.DEFAULT_OPERATOR_COLLECTOR_TERM, + eagerInputUnboxingCode: Boolean, retainHeader: Boolean = false, outputDirectly: Boolean = false): String = { @@ -135,9 +174,9 @@ object CalcCodeGenerator { throw new TableException("This calc has no useful projection and no filter. 
" + "It should be removed by CalcRemoveRule.") } else if (condition.isEmpty) { // only projection - val projectionCode = produceProjectionCode + val projectionCode = produceProjectionCode s""" - |${ctx.reuseInputUnboxingCode()} + |${if (eagerInputUnboxingCode) ctx.reuseInputUnboxingCode() else ""} |$projectionCode |""".stripMargin } else { @@ -145,14 +184,14 @@ object CalcCodeGenerator { // only filter if (onlyFilter) { s""" - |${ctx.reuseInputUnboxingCode()} + |${if (eagerInputUnboxingCode) ctx.reuseInputUnboxingCode() else ""} |${filterCondition.code} |if (${filterCondition.resultTerm}) { | ${produceOutputCode(inputTerm)} |} |""".stripMargin } else { // both filter and projection - val filterInputCode = ctx.reuseInputUnboxingCode() + val filterInputCode = ctx.reuseInputUnboxingCode() val filterInputSet = Set(ctx.reusableInputUnboxingExprs.keySet.toSeq: _*) // if any filter conditions, projection code will enter an new scope @@ -162,10 +201,10 @@ object CalcCodeGenerator { .filter(entry => !filterInputSet.contains(entry._1)) .values.map(_.code).mkString("\n") s""" - |$filterInputCode + |${if (eagerInputUnboxingCode) filterInputCode else ""} |${filterCondition.code} |if (${filterCondition.resultTerm}) { - | $projectionInputCode + | ${if (eagerInputUnboxingCode) projectionInputCode else ""} | $projectionCode |} |""".stripMargin diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/CollectorCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/CollectorCodeGenerator.scala index 098931c7969eb..2beab0dd85a1b 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/CollectorCodeGenerator.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/CollectorCodeGenerator.scala @@ -17,8 +17,8 @@ */ package org.apache.flink.table.codegen +import org.apache.flink.configuration.Configuration import org.apache.flink.table.`type`.InternalType -import org.apache.flink.table.api.TableConfig import org.apache.flink.table.codegen.CodeGenUtils._ import org.apache.flink.table.codegen.Indenter.toISC import org.apache.flink.table.generated.GeneratedCollector @@ -48,29 +48,31 @@ object CollectorCodeGenerator { bodyCode: String, inputType: InternalType, collectedType: InternalType, - config: TableConfig, inputTerm: String = CodeGenUtils.DEFAULT_INPUT1_TERM, collectedTerm: String = CodeGenUtils.DEFAULT_INPUT2_TERM, converter: String => String = (a) => a): GeneratedCollector[TableFunctionCollector[_]] = { - val className = newName(name) + val funcName = newName(name) val input1TypeClass = boxedTypeTermForType(inputType) val input2TypeClass = boxedTypeTermForType(collectedType) val funcCode = j""" - public class $className extends ${classOf[TableFunctionCollector[_]].getCanonicalName} { + public class $funcName extends ${classOf[TableFunctionCollector[_]].getCanonicalName} { ${ctx.reuseMemberCode()} - public $className() throws Exception { + public $funcName() throws Exception { ${ctx.reuseInitCode()} + } + + @Override + public void open(${className[Configuration]} parameters) throws Exception { ${ctx.reuseOpenCode()} } @Override public void collect(Object record) throws Exception { - super.collect(record); $input1TypeClass $inputTerm = ($input1TypeClass) getInput(); $input2TypeClass $collectedTerm = ($input2TypeClass) ${converter("record")}; ${ctx.reuseLocalVariableCode()} @@ -81,11 +83,16 @@ object CollectorCodeGenerator { @Override public void 
close() { + try { + ${ctx.reuseCloseCode()} + } catch (Exception e) { + throw new RuntimeException(e); + } } } """.stripMargin - new GeneratedCollector(className, funcCode, ctx.references.toArray) + new GeneratedCollector(funcName, funcCode, ctx.references.toArray) } } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/CorrelateCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/CorrelateCodeGenerator.scala index 19c356ce8edd4..9f2ba274c3896 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/CorrelateCodeGenerator.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/CorrelateCodeGenerator.scala @@ -18,8 +18,12 @@ package org.apache.flink.table.codegen +import org.apache.calcite.rel.`type`.RelDataType +import org.apache.calcite.rex._ +import org.apache.calcite.sql.SemiJoinType import org.apache.flink.api.common.functions.Function import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.configuration.Configuration import org.apache.flink.streaming.api.functions.ProcessFunction import org.apache.flink.streaming.api.transformations.{OneInputTransformation, StreamTransformation} import org.apache.flink.table.`type`.TypeConverters.createInternalTypeFromTypeInfo @@ -38,10 +42,6 @@ import org.apache.flink.table.runtime.CodeGenOperatorFactory import org.apache.flink.table.runtime.collector.TableFunctionCollector import org.apache.flink.table.runtime.util.StreamRecordCollector -import org.apache.calcite.rel.`type`.RelDataType -import org.apache.calcite.rex._ -import org.apache.calcite.sql.SemiJoinType - import scala.collection.JavaConversions._ object CorrelateCodeGenerator { @@ -172,6 +172,8 @@ object CorrelateCodeGenerator { val openUDTFCollector = s""" |$udtfCollectorTerm = new ${udtfCollector.getClassName}(); + |$udtfCollectorTerm.setRuntimeContext(getRuntimeContext()); + |$udtfCollectorTerm.open(new ${className[Configuration]}()); |$udtfCollectorTerm.setCollector( | new ${classOf[StreamRecordCollector[_]].getCanonicalName}( | ${CodeGenUtils.DEFAULT_OPERATOR_COLLECTOR_TERM })); @@ -205,7 +207,7 @@ object CorrelateCodeGenerator { |boolean hasOutput = $udtfCollectorTerm.isCollected(); |if (!hasOutput) { | $header - | $udtfCollectorTerm.getCollector().collect($nullRowTerm); + | $udtfCollectorTerm.outputResult($nullRowTerm); |} |""".stripMargin } else if (projectProgram.isDefined) { @@ -235,7 +237,7 @@ object CorrelateCodeGenerator { |if (!hasOutput) { | ${projectionExpression.code} | $header - | $udtfCollectorTerm.getCollector().collect($outputTerm); + | $udtfCollectorTerm.outputResult($outputTerm); |} |""".stripMargin @@ -258,7 +260,7 @@ object CorrelateCodeGenerator { |if (!hasOutput) { | $joinedRowTerm.replace(${exprGenerator.input1Term}, $nullRowTerm); | $header - | $udtfCollectorTerm.getCollector().collect($joinedRowTerm); + | $udtfCollectorTerm.outputResult($joinedRowTerm); |} |""".stripMargin @@ -345,7 +347,7 @@ object CorrelateCodeGenerator { s""" |${udtfResultExpr.code} |$header - |getCollector().collect(${udtfResultExpr.resultTerm}); + |outputResult(${udtfResultExpr.resultTerm}); """.stripMargin } else { val outputTerm = CodeGenUtils.newName("projectOut") @@ -370,7 +372,7 @@ object CorrelateCodeGenerator { s""" |$header |${projectionExpression.code} - |getCollector().collect(${projectionExpression.resultTerm}); + |outputResult(${projectionExpression.resultTerm}); 
""".stripMargin } } else { @@ -387,7 +389,7 @@ object CorrelateCodeGenerator { |${udtfResultExpr.code} |$joinedRowTerm.replace($inputTerm, ${udtfResultExpr.resultTerm}); |$header - |getCollector().collect($joinedRowTerm); + |outputResult($joinedRowTerm); """.stripMargin } @@ -414,7 +416,6 @@ object CorrelateCodeGenerator { collectorCode, inputType, udtfType, - config, inputTerm = inputTerm, collectedTerm = udtfInputTerm, converter = CodeGenUtils.genToInternal(ctx, udtfExternalType)) diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/LookupJoinCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/LookupJoinCodeGenerator.scala new file mode 100644 index 0000000000000..67dfac6479c39 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/LookupJoinCodeGenerator.scala @@ -0,0 +1,448 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.codegen + +import org.apache.calcite.rex.{RexNode, RexProgram} +import org.apache.flink.api.common.functions.FlatMapFunction +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.java.typeutils.RowTypeInfo +import org.apache.flink.configuration.Configuration +import org.apache.flink.streaming.api.functions.async.AsyncFunction +import org.apache.flink.table.`type`.{InternalType, RowType, TypeConverters} +import org.apache.flink.table.api.TableConfig +import org.apache.flink.table.calcite.FlinkTypeFactory +import org.apache.flink.table.codegen.CodeGenUtils._ +import org.apache.flink.table.codegen.GenerateUtils._ +import org.apache.flink.table.codegen.Indenter.toISC +import org.apache.flink.table.dataformat.DataFormatConverters.RowConverter +import org.apache.flink.table.dataformat.{BaseRow, GenericRow, JoinedRow} +import org.apache.flink.table.functions.{AsyncTableFunction, TableFunction} +import org.apache.flink.table.generated.{GeneratedCollector, GeneratedFunction, GeneratedResultFuture} +import org.apache.flink.table.plan.util.LookupJoinUtil.{ConstantLookupKey, FieldRefLookupKey, LookupKey} +import org.apache.flink.table.runtime.collector.{TableFunctionCollector, TableFunctionResultFuture} +import org.apache.flink.types.Row +import org.apache.flink.util.Collector + +import java.util + +object LookupJoinCodeGenerator { + + val ARRAY_LIST = className[util.ArrayList[_]] + + /** + * Generates a lookup function ([[TableFunction]]) + */ + def generateLookupFunction( + config: TableConfig, + typeFactory: FlinkTypeFactory, + inputType: InternalType, + returnType: InternalType, + tableReturnTypeInfo: TypeInformation[_], + lookupKeyInOrder: Array[Int], + // index field position -> lookup key + allLookupFields: 
Map[Int, LookupKey], + lookupFunction: TableFunction[_], + enableObjectReuse: Boolean) + : GeneratedFunction[FlatMapFunction[BaseRow, BaseRow]] = { + + val ctx = CodeGeneratorContext(config) + val (prepareCode, parameters) = prepareParameters( + ctx, + typeFactory, + inputType, + lookupKeyInOrder, + allLookupFields, + tableReturnTypeInfo.isInstanceOf[RowTypeInfo], + enableObjectReuse) + + val lookupFunctionTerm = ctx.addReusableFunction(lookupFunction) + val setCollectorCode = tableReturnTypeInfo match { + case rt: RowTypeInfo => + val converterCollector = new RowToBaseRowCollector(rt) + val term = ctx.addReusableObject(converterCollector, "collector") + s""" + |$term.setCollector($DEFAULT_COLLECTOR_TERM); + |$lookupFunctionTerm.setCollector($term); + """.stripMargin + case _ => + s"$lookupFunctionTerm.setCollector($DEFAULT_COLLECTOR_TERM);" + } + + val body = + s""" + |$prepareCode + |$setCollectorCode + |$lookupFunctionTerm.eval($parameters); + """.stripMargin + + FunctionCodeGenerator.generateFunction( + ctx, + "LookupFunction", + classOf[FlatMapFunction[BaseRow, BaseRow]], + body, + returnType, + inputType) + } + + /** + * Generates a async lookup function ([[AsyncTableFunction]]) + */ + def generateAsyncLookupFunction( + config: TableConfig, + typeFactory: FlinkTypeFactory, + inputType: InternalType, + returnType: InternalType, + tableReturnTypeInfo: TypeInformation[_], + lookupKeyInOrder: Array[Int], + allLookupFields: Map[Int, LookupKey], + asyncLookupFunction: AsyncTableFunction[_]) + : GeneratedFunction[AsyncFunction[BaseRow, AnyRef]] = { + + val ctx = CodeGeneratorContext(config) + val (prepareCode, parameters) = prepareParameters( + ctx, + typeFactory, + inputType, + lookupKeyInOrder, + allLookupFields, + tableReturnTypeInfo.isInstanceOf[RowTypeInfo], + fieldCopy = true) // always copy input field because of async buffer + + val lookupFunctionTerm = ctx.addReusableFunction(asyncLookupFunction) + + val body = + s""" + |$prepareCode + |$lookupFunctionTerm.eval($DEFAULT_COLLECTOR_TERM, $parameters); + """.stripMargin + + FunctionCodeGenerator.generateFunction( + ctx, + "LookupFunction", + classOf[AsyncFunction[BaseRow, AnyRef]], + body, + returnType, + inputType) + } + + /** + * Prepares parameters and returns (code, parameters) + */ + private def prepareParameters( + ctx: CodeGeneratorContext, + typeFactory: FlinkTypeFactory, + inputType: InternalType, + lookupKeyInOrder: Array[Int], + allLookupFields: Map[Int, LookupKey], + isExternalArgs: Boolean, + fieldCopy: Boolean): (String, String) = { + + val inputFieldExprs = for (i <- lookupKeyInOrder) yield { + allLookupFields.get(i) match { + case Some(ConstantLookupKey(dataType, literal)) => + generateLiteral(ctx, dataType, literal.getValue3) + case Some(FieldRefLookupKey(index)) => + generateInputAccess( + ctx, + inputType, + DEFAULT_INPUT1_TERM, + index, + nullableInput = false, + fieldCopy) + case None => + throw new CodeGenException("This should never happen!") + } + } + val codeAndArg = inputFieldExprs + .map { e => + val externalTypeInfo = TypeConverters.createExternalTypeInfoFromInternalType(e.resultType) + val bType = if (isExternalArgs) { + boxedTypeTermForExternalType(externalTypeInfo) + } else { + boxedTypeTermForType(e.resultType) + } + val assign = if (isExternalArgs) { + CodeGenUtils.genToExternal(ctx, externalTypeInfo, e.resultTerm) + } else { + e.resultTerm + } + val newTerm = newName("arg") + val code = + s""" + |$bType $newTerm = null; + |if (!${e.nullTerm}) { + | $newTerm = $assign; + |} + """.stripMargin + (code, 
newTerm) + } + (codeAndArg.map(_._1).mkString("\n"), codeAndArg.map(_._2).mkString(", ")) + } + + /** + * Generates collector for temporal join ([[Collector]]) + * + * Differs from CommonCorrelate.generateCollector which has no real condition because of + * FLINK-7865, here we should deal with outer join type when real conditions filtered result. + */ + def generateCollector( + ctx: CodeGeneratorContext, + inputType: RowType, + udtfTypeInfo: RowType, + resultType: RowType, + condition: Option[RexNode], + pojoFieldMapping: Option[Array[Int]], + retainHeader: Boolean = true): GeneratedCollector[TableFunctionCollector[BaseRow]] = { + + val inputTerm = DEFAULT_INPUT1_TERM + val udtfInputTerm = DEFAULT_INPUT2_TERM + + val exprGenerator = new ExprCodeGenerator(ctx, nullableInput = false) + .bindInput(udtfTypeInfo, inputTerm = udtfInputTerm, inputFieldMapping = pojoFieldMapping) + + val udtfResultExpr = exprGenerator.generateConverterResultExpression( + udtfTypeInfo, classOf[GenericRow]) + + val joinedRowTerm = CodeGenUtils.newName("joinedRow") + ctx.addReusableOutputRecord(resultType, classOf[JoinedRow], joinedRowTerm) + + val header = if (retainHeader) { + s"$joinedRowTerm.setHeader($inputTerm.getHeader());" + } else { + "" + } + + val body = + s""" + |${udtfResultExpr.code} + |$joinedRowTerm.replace($inputTerm, ${udtfResultExpr.resultTerm}); + |$header + |outputResult($joinedRowTerm); + """.stripMargin + + val collectorCode = if (condition.isEmpty) { + body + } else { + + val filterGenerator = new ExprCodeGenerator(ctx, nullableInput = false) + .bindInput(inputType, inputTerm) + .bindSecondInput(udtfTypeInfo, udtfInputTerm, pojoFieldMapping) + val filterCondition = filterGenerator.generateExpression(condition.get) + + s""" + |${filterCondition.code} + |if (${filterCondition.resultTerm}) { + | $body + |} + |""".stripMargin + } + + generateTableFunctionCollectorForJoinTable( + ctx, + "JoinTableFuncCollector", + collectorCode, + inputType, + udtfTypeInfo, + inputTerm = inputTerm, + collectedTerm = udtfInputTerm) + } + + /** + * The only differences against CollectorCodeGenerator.generateTableFunctionCollector is + * "super.collect" call is binding with collect join row in "body" code + */ + private def generateTableFunctionCollectorForJoinTable( + ctx: CodeGeneratorContext, + name: String, + bodyCode: String, + inputType: RowType, + collectedType: RowType, + inputTerm: String = DEFAULT_INPUT1_TERM, + collectedTerm: String = DEFAULT_INPUT2_TERM) + : GeneratedCollector[TableFunctionCollector[BaseRow]] = { + + val funcName = newName(name) + val input1TypeClass = boxedTypeTermForType(inputType) + val input2TypeClass = boxedTypeTermForType(collectedType) + + val funcCode = + s""" + public class $funcName extends ${classOf[TableFunctionCollector[_]].getCanonicalName} { + + ${ctx.reuseMemberCode()} + + public $funcName(Object[] references) throws Exception { + ${ctx.reuseInitCode()} + } + + @Override + public void open(${className[Configuration]} parameters) throws Exception { + ${ctx.reuseOpenCode()} + } + + @Override + public void collect(Object record) throws Exception { + $input1TypeClass $inputTerm = ($input1TypeClass) getInput(); + $input2TypeClass $collectedTerm = ($input2TypeClass) record; + ${ctx.reuseLocalVariableCode()} + ${ctx.reuseInputUnboxingCode()} + $bodyCode + } + + @Override + public void close() throws Exception { + ${ctx.reuseCloseCode()} + } + } + """.stripMargin + + new GeneratedCollector(funcName, funcCode, ctx.references.toArray) + } + + /** + * Generates a 
[[TableFunctionResultFuture]] that can be passed to Java compiler. + * + * @param config The TableConfig + * @param name Class name of the table function collector. Must not be unique but has to be a + * valid Java class identifier. + * @param leftInputType The type information of the element being collected + * @param collectedType The type information of the element collected by the collector + * @param condition The filter condition before collect elements + * @return instance of GeneratedCollector + */ + def generateTableAsyncCollector( + config: TableConfig, + name: String, + leftInputType: RowType, + collectedType: RowType, + condition: Option[RexNode]) + : GeneratedResultFuture[TableFunctionResultFuture[BaseRow]] = { + + val funcName = newName(name) + val input1TypeClass = boxedTypeTermForType(leftInputType) + val input2TypeClass = boxedTypeTermForType(collectedType) + val input1Term = DEFAULT_INPUT1_TERM + val input2Term = DEFAULT_INPUT2_TERM + val outTerm = "resultCollection" + + val ctx = CodeGeneratorContext(config) + + val body = if (condition.isEmpty) { + "getResultFuture().complete(records);" + } else { + val filterGenerator = new ExprCodeGenerator(ctx, nullableInput = false) + .bindInput(leftInputType, input1Term) + .bindSecondInput(collectedType, input2Term) + val filterCondition = filterGenerator.generateExpression(condition.get) + + s""" + |if (records == null || records.size() == 0) { + | getResultFuture().complete(java.util.Collections.emptyList()); + | return; + |} + |try { + | $input1TypeClass $input1Term = ($input1TypeClass) getInput(); + | $ARRAY_LIST $outTerm = new $ARRAY_LIST(); + | for (Object record : records) { + | $input2TypeClass $input2Term = ($input2TypeClass) record; + | ${ctx.reuseLocalVariableCode()} + | ${ctx.reuseInputUnboxingCode()} + | ${ctx.reusePerRecordCode()} + | ${filterCondition.code} + | if (${filterCondition.resultTerm}) { + | $outTerm.add(record); + | } + | } + | getResultFuture().complete($outTerm); + |} catch (Exception e) { + | getResultFuture().completeExceptionally(e); + |} + |""".stripMargin + } + + val funcCode = + j""" + public class $funcName extends ${classOf[TableFunctionResultFuture[_]].getCanonicalName} { + + ${ctx.reuseMemberCode()} + + public $funcName(Object[] references) throws Exception { + ${ctx.reuseInitCode()} + } + + @Override + public void open(${className[Configuration]} parameters) throws Exception { + ${ctx.reuseOpenCode()} + } + + @Override + public void complete(java.util.Collection records) throws Exception { + $body + } + + public void close() throws Exception { + ${ctx.reuseCloseCode()} + } + } + """.stripMargin + + new GeneratedResultFuture(funcName, funcCode, ctx.references.toArray) + } + + /** + * Generates calculate flatmap function for temporal join which is used + * to projection/filter the dimension table results + */ + def generateCalcMapFunction( + config: TableConfig, + calcProgram: Option[RexProgram], + tableSourceRowType: RowType) + : GeneratedFunction[FlatMapFunction[BaseRow, BaseRow]] = { + + val program = calcProgram.get + val condition = if (program.getCondition != null) { + Some(program.expandLocalRef(program.getCondition)) + } else { + None + } + CalcCodeGenerator.generateFunction( + tableSourceRowType, + "TableCalcMapFunction", + FlinkTypeFactory.toInternalRowType(program.getOutputRowType), + classOf[GenericRow], + program, + condition, + config) + } + + + // ---------------------------------------------------------------------------------------- + // Utility Classes + // 
---------------------------------------------------------------------------------------- + + class RowToBaseRowCollector(rowTypeInfo: RowTypeInfo) + extends TableFunctionCollector[Row] with Serializable { + + private val converter = new RowConverter(rowTypeInfo) + + override def collect(record: Row): Unit = { + val result = converter.toInternal(record) + outputResult(result) + } + } +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdColumnInterval.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdColumnInterval.scala index a4aeeae39a53a..d75558c95c6e8 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdColumnInterval.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/metadata/FlinkRelMdColumnInterval.scala @@ -527,7 +527,7 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] { getAggCallFromLocalAgg(aggCallIndex, agg.getAggCallList, agg.getInput.getRowType) case agg: BatchExecSortAggregate if agg.isMerge => val aggCallIndexInLocalAgg = getAggCallIndexInLocalAgg( - aggCallIndex, agg.getAggCallList, agg.getInput.getRowType) + aggCallIndex, agg.getAggCallList, agg.aggInputRowType) if (aggCallIndexInLocalAgg != null) { return fmq.getColumnInterval(agg.getInput, groupSet.length + aggCallIndexInLocalAgg) } else { diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/common/CommonCalc.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/common/CommonCalc.scala index e14ed70f5d524..4f540f516c93f 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/common/CommonCalc.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/common/CommonCalc.scala @@ -18,14 +18,14 @@ package org.apache.flink.table.plan.nodes.common -import org.apache.flink.table.plan.nodes.ExpressionFormat.ExpressionFormat -import org.apache.flink.table.plan.nodes.{ExpressionFormat, FlinkRelNode} - import org.apache.calcite.plan.{RelOptCluster, RelOptCost, RelOptPlanner, RelTraitSet} import org.apache.calcite.rel.core.Calc import org.apache.calcite.rel.metadata.RelMetadataQuery import org.apache.calcite.rel.{RelNode, RelWriter} import org.apache.calcite.rex.{RexCall, RexInputRef, RexLiteral, RexProgram} +import org.apache.flink.table.plan.nodes.ExpressionFormat.ExpressionFormat +import org.apache.flink.table.plan.nodes.{ExpressionFormat, FlinkRelNode} +import org.apache.flink.table.plan.util.RelExplainUtil.conditionToString import scala.collection.JavaConversions._ @@ -60,19 +60,9 @@ abstract class CommonCalc( override def explainTerms(pw: RelWriter): RelWriter = { pw.input("input", getInput) .item("select", projectionToString()) - .itemIf("where", conditionToString(), calcProgram.getCondition != null) - } - - protected def conditionToString(): String = { - val cond = calcProgram.getCondition - val inputFieldNames = calcProgram.getInputRowType.getFieldNames.toList - val localExprs = calcProgram.getExprList.toList - - if (cond != null) { - getExpressionString(cond, inputFieldNames, Some(localExprs)) - } else { - "" - } + .itemIf("where", + conditionToString(calcProgram, getExpressionString), + calcProgram.getCondition != null) } protected def projectionToString( diff --git 
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/common/CommonLookupJoin.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/common/CommonLookupJoin.scala new file mode 100644 index 0000000000000..6d2e40dea649d --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/common/CommonLookupJoin.scala @@ -0,0 +1,731 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.plan.nodes.common + +import com.google.common.primitives.Primitives +import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} +import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeField} +import org.apache.calcite.rel.core.{JoinInfo, JoinRelType} +import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel} +import org.apache.calcite.rex._ +import org.apache.calcite.sql.SqlKind +import org.apache.calcite.sql.fun.SqlStdOperatorTable +import org.apache.calcite.sql.validate.SqlValidatorUtil +import org.apache.calcite.tools.RelBuilder +import org.apache.calcite.util.mapping.IntPair +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.java.typeutils.{RowTypeInfo, TypeExtractor} +import org.apache.flink.streaming.api.datastream.AsyncDataStream.OutputMode +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment +import org.apache.flink.streaming.api.functions.async.ResultFuture +import org.apache.flink.streaming.api.operators.ProcessOperator +import org.apache.flink.streaming.api.operators.async.AsyncWaitOperator +import org.apache.flink.streaming.api.transformations.{OneInputTransformation, StreamTransformation} +import org.apache.flink.table.`type`._ +import org.apache.flink.table.api.{TableConfig, TableException, TableSchema} +import org.apache.flink.table.calcite.FlinkTypeFactory +import org.apache.flink.table.codegen.LookupJoinCodeGenerator._ +import org.apache.flink.table.codegen.{CodeGeneratorContext, LookupJoinCodeGenerator} +import org.apache.flink.table.dataformat.BaseRow +import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils.{getParamClassesConsiderVarArgs, getUserDefinedMethod, signatureToString, signaturesToString} +import org.apache.flink.table.functions.{AsyncTableFunction, TableFunction, UserDefinedFunction} +import org.apache.flink.table.plan.nodes.FlinkRelNode +import org.apache.flink.table.plan.schema.TimeIndicatorRelDataType +import org.apache.flink.table.plan.util.{JoinTypeUtil, RelExplainUtil} +import org.apache.flink.table.plan.util.LookupJoinUtil._ +import org.apache.flink.table.runtime.join.lookup.{AsyncLookupJoinRunner, LookupJoinRunner, AsyncLookupJoinWithCalcRunner, LookupJoinWithCalcRunner} +import 
org.apache.flink.table.sources.TableIndex.IndexType +import org.apache.flink.table.sources.{LookupConfig, LookupableTableSource, TableIndex, TableSource} +import org.apache.flink.table.typeutils.BaseRowTypeInfo +import org.apache.flink.types.Row + +import java.util.Collections + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +/** + * Common abstract RelNode for temporal table join which shares most methods. + * @param input input rel node + * @param tableSource the table source to be temporal joined + * @param tableRowType the row type of the table source + * @param calcOnTemporalTable the calc (projection&filter) after table scan before joining + */ +abstract class CommonLookupJoin( + cluster: RelOptCluster, + traitSet: RelTraitSet, + input: RelNode, + val tableSource: TableSource[_], + tableRowType: RelDataType, + val calcOnTemporalTable: Option[RexProgram], + val joinInfo: JoinInfo, + val joinType: JoinRelType) + extends SingleRel(cluster, traitSet, input) + with FlinkRelNode { + + val joinKeyPairs: Array[IntPair] = getTemporalTableJoinKeyPairs(joinInfo, calcOnTemporalTable) + val indexKeys: Array[TableIndex] = getTableIndexes(tableSource) + // all potential index keys, mapping from field index in table source to LookupKey + val allLookupKeys: Map[Int, LookupKey] = analyzeLookupKeys( + cluster.getRexBuilder, + joinKeyPairs, + indexKeys, + tableSource.getTableSchema, + calcOnTemporalTable) + // the matched best lookup fields which is in defined order, maybe empty + val matchedLookupFields: Option[Array[Int]] = findMatchedIndex( + indexKeys, + tableSource.getTableSchema, + allLookupKeys) + + override def deriveRowType(): RelDataType = { + val flinkTypeFactory = cluster.getTypeFactory.asInstanceOf[FlinkTypeFactory] + val rightType = if (calcOnTemporalTable.isDefined) { + calcOnTemporalTable.get.getOutputRowType + } else { + tableRowType + } + SqlValidatorUtil.deriveJoinRowType( + input.getRowType, + rightType, + joinType, + flinkTypeFactory, + null, + Collections.emptyList[RelDataTypeField]) + } + + override def explainTerms(pw: RelWriter): RelWriter = { + val remaining = joinInfo.getRemaining(cluster.getRexBuilder) + val joinCondition = if (remaining.isAlwaysTrue) { + None + } else { + Some(remaining) + } + + val inputFieldNames = input.getRowType.getFieldNames.asScala.toArray + val tableFieldNames = tableSource.getTableSchema.getFieldNames + val rightFieldNames = calcOnTemporalTable match { + case Some(calc) => calc.getOutputRowType.getFieldNames.asScala.toArray + case None => tableFieldNames + } + val resultFieldNames = getRowType.getFieldNames.asScala.toArray + val lookupConfig = getLookupConfig(tableSource.asInstanceOf[LookupableTableSource[_]]) + val whereString = calcOnTemporalTable match { + case Some(calc) => RelExplainUtil.conditionToString(calc, getExpressionString) + case None => "N/A" + } + + super.explainTerms(pw) + .item("table", tableSource.explainSource()) + .item("joinType", JoinTypeUtil.getFlinkJoinType(joinType)) + .item("async", lookupConfig.isAsyncEnabled) + .item("on", joinOnToString(inputFieldNames, rightFieldNames, joinInfo)) + .itemIf("where", whereString, calcOnTemporalTable.isDefined) + .itemIf("joinCondition", + joinConditionToString(resultFieldNames, joinCondition), + joinCondition.isDefined) + .item("select", joinSelectionToString(getRowType)) + } + + // ---------------------------------------------------------------------------------------- + // Physical Translation + // 
---------------------------------------------------------------------------------------- + + def translateToPlanInternal( + inputTransformation: StreamTransformation[BaseRow], + env: StreamExecutionEnvironment, + config: TableConfig, + relBuilder: RelBuilder): StreamTransformation[BaseRow] = { + + val inputRowType = FlinkTypeFactory.toInternalRowType(input.getRowType) + val tableSourceRowType = FlinkTypeFactory.toInternalRowType(tableRowType) + val resultRowType = FlinkTypeFactory.toInternalRowType(getRowType) + val tableSchema = tableSource.getTableSchema + + // validate whether the node is valid and supported. + validate( + tableSource, + inputRowType, + tableSourceRowType, + indexKeys, + allLookupKeys, + matchedLookupFields, + joinType) + + val lookupFieldsInOrder = matchedLookupFields.get + val lookupFieldNamesInOrder = lookupFieldsInOrder.map(tableSchema.getFieldNames()(_)) + val lookupFieldTypesInOrder = lookupFieldsInOrder + .map(tableSchema.getFieldTypes()(_)) + .map(TypeConverters.createInternalTypeFromTypeInfo) + val remainingCondition = getRemainingJoinCondition( + cluster.getRexBuilder, + relBuilder, + input.getRowType, + tableRowType, + calcOnTemporalTable, + lookupFieldsInOrder, + joinKeyPairs, + joinInfo, + allLookupKeys) + + val lookupableTableSource = tableSource.asInstanceOf[LookupableTableSource[_]] + val lookupConfig = getLookupConfig(lookupableTableSource) + val leftOuterJoin = joinType == JoinRelType.LEFT + + val operator = if (lookupConfig.isAsyncEnabled) { + val asyncBufferCapacity= lookupConfig.getAsyncBufferCapacity + val asyncTimeout = lookupConfig.getAsyncTimeoutMs + + val asyncLookupFunction = lookupableTableSource + .getAsyncLookupFunction(lookupFieldNamesInOrder) + // return type valid check + val udtfResultType = asyncLookupFunction.getResultType + val extractedResultTypeInfo = TypeExtractor.createTypeInfo( + asyncLookupFunction, + classOf[AsyncTableFunction[_]], + asyncLookupFunction.getClass, + 0) + checkUdtfReturnType( + tableSource.explainSource(), + tableSource.getReturnType, + udtfResultType, + extractedResultTypeInfo) + val parameters = Array(new GenericType(classOf[ResultFuture[_]])) ++ lookupFieldTypesInOrder + checkEvalMethodSignature( + asyncLookupFunction, + parameters, + extractedResultTypeInfo) + + val generatedFetcher = LookupJoinCodeGenerator.generateAsyncLookupFunction( + config, + relBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory], + inputRowType, + resultRowType, + tableSource.getReturnType, + lookupFieldsInOrder, + allLookupKeys, + asyncLookupFunction) + + val asyncFunc = if (calcOnTemporalTable.isDefined) { + // a projection or filter after table source scan + val rightRowType = FlinkTypeFactory + .toInternalRowType(calcOnTemporalTable.get.getOutputRowType) + val generatedResultFuture = LookupJoinCodeGenerator.generateTableAsyncCollector( + config, + "TableFunctionResultFuture", + inputRowType, + rightRowType, + remainingCondition) + val generatedCalc = generateCalcMapFunction( + config, + calcOnTemporalTable, + tableSourceRowType) + + new AsyncLookupJoinWithCalcRunner( + generatedFetcher, + generatedCalc, + generatedResultFuture, + tableSource.getReturnType, + rightRowType.toTypeInfo, + leftOuterJoin, + lookupConfig.getAsyncBufferCapacity) + } else { + // right type is the same as table source row type, because no calc after temporal table + val rightRowType = tableSourceRowType + val generatedResultFuture = LookupJoinCodeGenerator.generateTableAsyncCollector( + config, + "TableFunctionResultFuture", + inputRowType, + 
rightRowType, + remainingCondition) + new AsyncLookupJoinRunner( + generatedFetcher, + generatedResultFuture, + tableSource.getReturnType, + rightRowType.toTypeInfo, + leftOuterJoin, + asyncBufferCapacity) + } + + // force ORDERED output mode currently, optimize it to UNORDERED + // when the downstream do not need orderness + new AsyncWaitOperator(asyncFunc, asyncTimeout, asyncBufferCapacity, OutputMode.ORDERED) + } else { + // sync join + val lookupFunction = lookupableTableSource.getLookupFunction(lookupFieldNamesInOrder) + // return type valid check + val udtfResultType = lookupFunction.getResultType + val extractedResultTypeInfo = TypeExtractor.createTypeInfo( + lookupFunction, + classOf[TableFunction[_]], + lookupFunction.getClass, + 0) + checkUdtfReturnType( + tableSource.explainSource(), + tableSource.getReturnType, + udtfResultType, + extractedResultTypeInfo) + checkEvalMethodSignature( + lookupFunction, + lookupFieldTypesInOrder, + extractedResultTypeInfo) + + val generatedFetcher = LookupJoinCodeGenerator.generateLookupFunction( + config, + relBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory], + inputRowType, + resultRowType, + tableSource.getReturnType, + lookupFieldsInOrder, + allLookupKeys, + lookupFunction, + env.getConfig.isObjectReuseEnabled) + + val ctx = CodeGeneratorContext(config) + val processFunc = if (calcOnTemporalTable.isDefined) { + // a projection or filter after table source scan + val rightRowType = FlinkTypeFactory + .toInternalRowType(calcOnTemporalTable.get.getOutputRowType) + val generatedCollector = generateCollector( + ctx, + inputRowType, + rightRowType, + resultRowType, + remainingCondition, + None) + val generatedCalc = generateCalcMapFunction( + config, + calcOnTemporalTable, + tableSourceRowType) + + new LookupJoinWithCalcRunner( + generatedFetcher, + generatedCalc, + generatedCollector, + leftOuterJoin, + rightRowType.getArity) + } else { + // right type is the same as table source row type, because no calc after temporal table + val rightRowType = tableSourceRowType + val generatedCollector = generateCollector( + ctx, + inputRowType, + rightRowType, + resultRowType, + remainingCondition, + None) + new LookupJoinRunner( + generatedFetcher, + generatedCollector, + leftOuterJoin, + rightRowType.getArity) + } + new ProcessOperator(processFunc) + } + + new OneInputTransformation( + inputTransformation, + "LookupJoin", + operator, + resultRowType.toTypeInfo, + inputTransformation.getParallelism) + } + + def getLookupConfig(tableSource: LookupableTableSource[_]): LookupConfig = { + if (tableSource.getLookupConfig != null) { + tableSource.getLookupConfig + } else { + LookupConfig.DEFAULT + } + } + + private def rowTypeEquals(expected: TypeInformation[_], actual: TypeInformation[_]): Boolean = { + // check internal and external type, cause we will auto convert external class to internal + // class (eg: Row => BaseRow). 
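+  // For example (informal note): a lookup function whose extracted result type is Row can serve
+  // a table source whose declared return type is a BaseRowTypeInfo; only the two row classes are
+  // compared here, while the eval() parameter types are validated separately in
+  // checkEvalMethodSignature.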
+ (expected.getTypeClass == classOf[BaseRow] || expected.getTypeClass == classOf[Row]) && + (actual.getTypeClass == classOf[BaseRow] || actual.getTypeClass == classOf[Row]) + } + + def checkEvalMethodSignature( + func: UserDefinedFunction, + expectedTypes: Array[InternalType], + udtfReturnType: TypeInformation[_]) + : Array[Class[_]] = { + val expectedTypeClasses = if (udtfReturnType.getTypeClass == classOf[Row]) { + expectedTypes.map(InternalTypeUtils.getExternalClassForType) + } else { + expectedTypes.map{ + case gt: GenericType[_] => gt.getTypeInfo.getTypeClass // special case for generic type + case t@_ => InternalTypeUtils.getInternalClassForType(t) + } + } + val method = getUserDefinedMethod( + func, + "eval", + expectedTypeClasses, + expectedTypes, + _ => expectedTypes.indices.map(_ => null).toArray, + parameterTypeEquals, + (_, _) => false).getOrElse { + val msg = s"Given parameter types of the lookup TableFunction of TableSource " + + s"[${tableSource.explainSource()}] do not match the expected signature.\n" + + s"Expected: eval${signatureToString(expectedTypeClasses)} \n" + + s"Actual: eval${signaturesToString(func, "eval")}" + throw new TableException(msg) + } + getParamClassesConsiderVarArgs(method.isVarArgs, + method.getParameterTypes, expectedTypes.length) + } + + private def parameterTypeEquals(candidate: Class[_], expected: Class[_]): Boolean = { + candidate == null || + candidate == expected || + expected == classOf[Object] || + candidate == classOf[Object] || // Special case when we don't know the type + expected.isPrimitive && Primitives.wrap(expected) == candidate || + (candidate.isArray && + expected.isArray && + candidate.getComponentType.isInstanceOf[Object] && + expected.getComponentType == classOf[Object]) + } + + private def getRemainingJoinCondition( + rexBuilder: RexBuilder, + relBuilder: RelBuilder, + leftRelDataType: RelDataType, + tableRelDataType: RelDataType, + calcOnTemporalTable: Option[RexProgram], + checkedLookupFields: Array[Int], + joinKeyPairs: Array[IntPair], + joinInfo: JoinInfo, + allLookupKeys: Map[Int, LookupKey]): Option[RexNode] = { + val remainingPairs = joinKeyPairs.filter(p => !checkedLookupFields.contains(p.target)) + // convert remaining pairs to RexInputRef tuple for building sqlStdOperatorTable.EQUALS calls + val remainingAnds = remainingPairs.map { p => + val leftFieldType = leftRelDataType.getFieldList.get(p.source).getType + val leftInputRef = new RexInputRef(p.source, leftFieldType) + val rightInputRef = calcOnTemporalTable match { + case Some(program) => + val rightKeyIdx = program + .getOutputRowType.getFieldNames + .indexOf(program.getInputRowType.getFieldNames.get(p.target)) + new RexInputRef( + leftRelDataType.getFieldCount + rightKeyIdx, + program.getOutputRowType.getFieldList.get(rightKeyIdx).getType) + + case None => + new RexInputRef( + leftRelDataType.getFieldCount + p.target, + tableRelDataType.getFieldList.get(p.target).getType) + } + (leftInputRef, rightInputRef) + } + val equiAnds = relBuilder.and(remainingAnds.map(p => relBuilder.equals(p._1, p._2)): _*) + val condition = relBuilder.and(equiAnds, joinInfo.getRemaining(rexBuilder)) + if (condition.isAlwaysTrue) { + None + } else { + Some(condition) + } + } + + + /** + * Gets the join key pairs from left input field index to temporal table field index + * @param joinInfo the join information of temporal table join + * @param calcOnTemporalTable the calc programs on temporal table + */ + private def getTemporalTableJoinKeyPairs( + joinInfo: JoinInfo, + 
calcOnTemporalTable: Option[RexProgram]): Array[IntPair] = { + val joinPairs = joinInfo.pairs().asScala.toArray + calcOnTemporalTable match { + case Some(program) => + // the target key of joinInfo is the calc output fields, we have to remapping to table here + val keyPairs = new mutable.ArrayBuffer[IntPair]() + joinPairs.map { + p => + val calcSrcIdx = getIdenticalSourceField(program, p.target) + if (calcSrcIdx != -1) { + keyPairs += new IntPair(p.source, calcSrcIdx) + } + } + keyPairs.toArray + case None => joinPairs + } + } + + /** + * Analyze potential lookup keys (including [[ConstantLookupKey]] and [[FieldRefLookupKey]]) + * of the temporal table from the join condition and calc program on the temporal table. + * + * @param rexBuilder the RexBuilder + * @param joinKeyPairs join key pairs from left input field index to temporal table field index + * @param calcOnTemporalTable the calc program on temporal table + * @return all the potential lookup keys + */ + def analyzeLookupKeys( + rexBuilder: RexBuilder, + joinKeyPairs: Array[IntPair], + tableIndexes: Array[TableIndex], + temporalTableSchema: TableSchema, + calcOnTemporalTable: Option[RexProgram]): Map[Int, LookupKey] = { + val fieldNames = temporalTableSchema.getFieldNames + val allIndexFields = tableIndexes + .flatMap(_.getIndexedColumns.asScala.map(fieldNames.indexOf(_))) + .toSet + // field_index_in_table_source => constant_lookup_key + val constantLookupKeys = new mutable.HashMap[Int, ConstantLookupKey] + // analyze constant lookup keys + if (calcOnTemporalTable.isDefined && null != calcOnTemporalTable.get.getCondition) { + val program = calcOnTemporalTable.get + val condition = RexUtil.toCnf( + cluster.getRexBuilder, + program.expandLocalRef(program.getCondition)) + // presume 'A = 1 AND A = 2' will be reduced to ALWAYS_FALSE + extractConstantFieldsFromEquiCondition(condition, allIndexFields, constantLookupKeys) + } + val fieldRefLookupKeys = joinKeyPairs.map(p => (p.target, FieldRefLookupKey(p.source))) + (constantLookupKeys ++ fieldRefLookupKeys).toMap + } + + private def findMatchedIndex( + tableIndexes: Array[TableIndex], + temporalTableSchema: TableSchema, + allLookupKeys: Map[Int, LookupKey]): Option[Array[Int]] = { + + val fieldNames = temporalTableSchema.getFieldNames + + // [(indexFields, isUniqueIndex)] + val indexes: Array[(Array[Int], Boolean)] = tableIndexes.map { tableIndex => + val indexFields = tableIndex.getIndexedColumns.asScala.map(fieldNames.indexOf(_)).toArray + val isUniqueIndex = tableIndex.getIndexType.equals(IndexType.UNIQUE) + (indexFields, isUniqueIndex) + } + + val matchedIndexes = indexes.filter(_._1.forall(allLookupKeys.contains)) + if (matchedIndexes.length > 1) { + // find a best one, we prefer a unique index key here + val uniqueIndex = matchedIndexes.find(_._2).map(_._1) + if (uniqueIndex.isDefined) { + uniqueIndex + } else { + // all the matched index are normal index, select anyone from matched indexes + matchedIndexes.map(_._1).headOption + } + } else { + // select anyone from matched indexes + matchedIndexes.map(_._1).headOption + } + } + + // ---------------------------------------------------------------------------------------- + // Physical Optimization Utilities + // ---------------------------------------------------------------------------------------- + + // this is highly inspired by Calcite's RexProgram#getSourceField(int) + private def getIdenticalSourceField(rexProgram: RexProgram, outputOrdinal: Int): Int = { + assert((outputOrdinal >= 0) && (outputOrdinal < 
rexProgram.getProjectList.size())) + val project = rexProgram.getProjectList.get(outputOrdinal) + var index = project.getIndex + while (true) { + var expr = rexProgram.getExprList.get(index) + expr match { + case call: RexCall if call.getOperator == SqlStdOperatorTable.IN_FENNEL => + // drill through identity function + expr = call.getOperands.get(0) + case call: RexCall if call.getOperator == SqlStdOperatorTable.CAST => + // drill through identity function + expr = call.getOperands.get(0) + case _ => + } + expr match { + case ref: RexLocalRef => index = ref.getIndex + case ref: RexInputRef => return ref.getIndex + case _ => return -1 + } + } + -1 + } + + private def extractConstantFieldsFromEquiCondition( + condition: RexNode, + allIndexFields: Set[Int], + constantFieldMap: mutable.HashMap[Int, ConstantLookupKey]): Unit = condition match { + case c: RexCall if c.getKind == SqlKind.AND => + c.getOperands.asScala.foreach(r => extractConstantField(r, allIndexFields, constantFieldMap)) + case rex: RexNode => extractConstantField(rex, allIndexFields, constantFieldMap) + case _ => + } + + private def extractConstantField( + pred: RexNode, + allIndexFields: Set[Int], + constantFieldMap: mutable.HashMap[Int, ConstantLookupKey]): Unit = pred match { + case c: RexCall if c.getKind == SqlKind.EQUALS => + val left = c.getOperands.get(0) + val right = c.getOperands.get(1) + val (inputRef, literal) = (left, right) match { + case (literal: RexLiteral, ref: RexInputRef) => (ref, literal) + case (ref: RexInputRef, literal: RexLiteral) => (ref, literal) + } + if (allIndexFields.contains(inputRef.getIndex)) { + val dataType = FlinkTypeFactory.toInternalType(inputRef.getType) + constantFieldMap.put(inputRef.getIndex, ConstantLookupKey(dataType, literal)) + } + case _ => // ignore + } + + // ---------------------------------------------------------------------------------------- + // Validation + // ---------------------------------------------------------------------------------------- + + def validate( + tableSource: TableSource[_], + inputRowType: RowType, + tableSourceRowType: RowType, + tableIndexes: Array[TableIndex], + allLookupKeys: Map[Int, LookupKey], + matchedLookupFields: Option[Array[Int]], + joinType: JoinRelType): Unit = { + + // checked PRIMARY KEY or (UNIQUE) INDEX is defined. 
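+    // Illustration (hypothetical names, for documentation only): if the dimension table declares
+    // a UNIQUE index on column `id`, then a join such as
+    //   SELECT * FROM Orders AS o
+    //   JOIN Dim FOR SYSTEM_TIME AS OF o.proctime AS d ON o.dim_id = d.id
+    // satisfies the checks below, while a join that does not cover all indexed columns with
+    // equality conditions is rejected.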
+ if (tableIndexes.isEmpty) { + throw new TableException( + s"Temporal table join requires that table [${tableSource.explainSource()}] defines " + + s"a PRIMARY KEY or (UNIQUE) INDEX.") + } +
+ // check that the join covers ALL fields of a PRIMARY KEY or (UNIQUE) INDEX + if (allLookupKeys.isEmpty || matchedLookupFields.isEmpty) { + throw new TableException( + "Temporal table join requires an equality condition on ALL fields of " + + s"table [${tableSource.explainSource()}]'s PRIMARY KEY or (UNIQUE) INDEX(es).") + } +
+ if (!tableSource.isInstanceOf[LookupableTableSource[_]]) { + throw new TableException(s"TableSource of [${tableSource.explainSource()}] must " + + s"implement the LookupableTableSource interface when it is used in a temporal table join.") + } +
+ val checkedLookupFields = matchedLookupFields.get +
+ val lookupKeyPairs = joinKeyPairs.filter(p => checkedLookupFields.contains(p.target)) + val leftKeys = lookupKeyPairs.map(_.source) + val rightKeys = lookupKeyPairs.map(_.target) + val leftKeyTypes = leftKeys.map(inputRowType.getFieldTypes()(_)) + // use the original key pairs to validate key types (rightKeys may include constant keys) + val rightKeyTypes = rightKeys.map(tableSourceRowType.getFieldTypes()(_)) +
+ // check that the join key types on both sides are compatible + val incompatibleConditions = new mutable.ArrayBuffer[String]() + for (i <- lookupKeyPairs.indices) { + val leftType = leftKeyTypes(i) + val rightType = rightKeyTypes(i) + if (leftType != rightType) { + val leftName = inputRowType.getFieldNames()(leftKeys(i)) + val rightName = tableSourceRowType.getFieldNames()(rightKeys(i)) + val condition = s"$leftName[$leftType]=$rightName[$rightType]" + incompatibleConditions += condition + } + } + if (incompatibleConditions.nonEmpty) { + throw new TableException(s"Temporal table join requires join keys of the same type, " + + s"but found incompatible conditions: ${incompatibleConditions.mkString(", ")}") + } +
+ if (joinType != JoinRelType.LEFT && joinType != JoinRelType.INNER) { + throw new TableException( + "Temporal table join currently only supports INNER JOIN and LEFT JOIN, " + + "but was " + joinType.toString + " JOIN") + } +
+ val tableReturnType = tableSource.getReturnType + if (!tableReturnType.isInstanceOf[BaseRowTypeInfo] && + !tableReturnType.isInstanceOf[RowTypeInfo]) { + throw new TableException( + "Temporal table join only supports Row or BaseRow as the return type of the temporal table."
+ + " But was " + tableReturnType) + } + + // success + } + + def checkUdtfReturnType( + tableDesc: String, + tableReturnTypeInfo: TypeInformation[_], + udtfReturnTypeInfo: TypeInformation[_], + extractedUdtfReturnTypeInfo: TypeInformation[_]): Unit = { + if (udtfReturnTypeInfo == null) { + if (!rowTypeEquals(tableReturnTypeInfo, extractedUdtfReturnTypeInfo)) { + throw new TableException( + s"The TableSource [$tableDesc] return type $tableReturnTypeInfo does not match " + + s"its lookup function extracted return type $extractedUdtfReturnTypeInfo") + } + if (extractedUdtfReturnTypeInfo.getTypeClass != classOf[BaseRow] && + extractedUdtfReturnTypeInfo.getTypeClass != classOf[Row]) { + throw new TableException( + s"Result type of the lookup TableFunction of TableSource [$tableDesc] is " + + s"$extractedUdtfReturnTypeInfo type, " + + s"but currently only Row and BaseRow are supported.") + } + } else { + if (!rowTypeEquals(tableReturnTypeInfo, udtfReturnTypeInfo)) { + throw new TableException( + s"The TableSource [$tableDesc] return type $tableReturnTypeInfo " + + s"does not match its lookup function return type $udtfReturnTypeInfo") + } + if (!udtfReturnTypeInfo.isInstanceOf[BaseRowTypeInfo] && + !udtfReturnTypeInfo.isInstanceOf[RowTypeInfo]) { + throw new TableException( + "Result type of the async lookup TableFunction of TableSource " + + s"'$tableDesc' is $udtfReturnTypeInfo type, " + + s"currently only Row and BaseRow are supported.") + } + } + } + + // ---------------------------------------------------------------------------------------- + // toString Utilities + // ---------------------------------------------------------------------------------------- + + private def joinSelectionToString(resultType: RelDataType): String = { + resultType.getFieldNames.asScala.toList.mkString(", ") + } + + private def joinConditionToString( + resultFieldNames: Array[String], + joinCondition: Option[RexNode]): String = joinCondition match { + case Some(condition) => + getExpressionString(condition, resultFieldNames.toList, None) + case None => "N/A" + } + + private def joinOnToString( + inputFieldNames: Array[String], + tableFieldNames: Array[String], + joinInfo: JoinInfo): String = { + val keyPairNames = joinInfo.pairs().asScala.map { p => + s"${inputFieldNames(p.source)}=${ + if (p.target >= 0 && p.target < tableFieldNames.length) tableFieldNames(p.target) else -1 + }" + } + keyPairNames.mkString(", ") + } +} + diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalSnapshot.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalSnapshot.scala new file mode 100644 index 0000000000000..294d0f78a5281 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalSnapshot.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.plan.nodes.logical + +import org.apache.flink.table.plan.nodes.FlinkConventions + +import org.apache.calcite.plan._ +import org.apache.calcite.rel.convert.ConverterRule +import org.apache.calcite.rel.core.Snapshot +import org.apache.calcite.rel.logical.LogicalSnapshot +import org.apache.calcite.rel.metadata.{RelMdCollation, RelMetadataQuery} +import org.apache.calcite.rel.{RelCollation, RelCollationTraitDef, RelNode} +import org.apache.calcite.rex.RexNode + +import java.util +import java.util.function.Supplier + +/** + * Sub-class of [[Snapshot]] that is a relational expression which returns + * the contents of a relation expression as it was at a given time in the past. + */ +class FlinkLogicalSnapshot( + cluster: RelOptCluster, + traits: RelTraitSet, + child: RelNode, + period: RexNode) + extends Snapshot(cluster, traits, child, period) + with FlinkLogicalRel { + + override def copy( + traitSet: RelTraitSet, + input: RelNode, + period: RexNode): Snapshot = { + new FlinkLogicalSnapshot(cluster, traitSet, input, period) + } + + override def computeSelfCost(planner: RelOptPlanner, mq: RelMetadataQuery): RelOptCost = { + val rowCnt = mq.getRowCount(this) + val rowSize = mq.getAverageRowSize(this) + planner.getCostFactory.makeCost(rowCnt, rowCnt, rowCnt * rowSize) + } + +} + +class FlinkLogicalSnapshotConverter + extends ConverterRule( + classOf[LogicalSnapshot], + Convention.NONE, + FlinkConventions.LOGICAL, + "FlinkLogicalSnapshotConverter") { + + def convert(rel: RelNode): RelNode = { + val snapshot = rel.asInstanceOf[LogicalSnapshot] + val newInput = RelOptRule.convert(snapshot.getInput, FlinkConventions.LOGICAL) + FlinkLogicalSnapshot.create(newInput, snapshot.getPeriod) + } +} + +object FlinkLogicalSnapshot { + + val CONVERTER = new FlinkLogicalSnapshotConverter + + def create(input: RelNode, period: RexNode): FlinkLogicalSnapshot = { + val cluster = input.getCluster + val mq = cluster.getMetadataQuery + val traitSet = cluster.traitSet.replace(Convention.NONE).replaceIfs( + RelCollationTraitDef.INSTANCE, new Supplier[util.List[RelCollation]]() { + def get: util.List[RelCollation] = RelMdCollation.snapshot(mq, input) + }) + val snapshot = new FlinkLogicalSnapshot(cluster, traitSet, input, period) + val newTraitSet = snapshot.getTraitSet + .replace(FlinkConventions.LOGICAL).simplify() + snapshot.copy(newTraitSet, input, period).asInstanceOf[FlinkLogicalSnapshot] + } +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecLookupJoin.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecLookupJoin.scala new file mode 100644 index 0000000000000..694db12db075c --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecLookupJoin.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.plan.nodes.physical.batch + +import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} +import org.apache.calcite.rel.RelNode +import org.apache.calcite.rel.`type`.RelDataType +import org.apache.calcite.rel.core.{JoinInfo, JoinRelType} +import org.apache.calcite.rex.RexProgram +import org.apache.flink.runtime.operators.DamBehavior +import org.apache.flink.streaming.api.transformations.StreamTransformation +import org.apache.flink.table.api.{BatchTableEnvironment, TableConfigOptions} +import org.apache.flink.table.dataformat.BaseRow +import org.apache.flink.table.plan.nodes.common.CommonLookupJoin +import org.apache.flink.table.plan.nodes.exec.{BatchExecNode, ExecNode} +import org.apache.flink.table.sources.TableSource + +import java.util + +import scala.collection.JavaConversions._ + +/** + * Batch physical RelNode for temporal table join. + */ +class BatchExecLookupJoin( + cluster: RelOptCluster, + traitSet: RelTraitSet, + input: RelNode, + tableSource: TableSource[_], + tableRowType: RelDataType, + tableCalcProgram: Option[RexProgram], + joinInfo: JoinInfo, + joinType: JoinRelType) + extends CommonLookupJoin( + cluster, + traitSet, + input, + tableSource, + tableRowType, + tableCalcProgram, + joinInfo, + joinType) + with BatchPhysicalRel + with BatchExecNode[BaseRow] { + + override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = { + new BatchExecLookupJoin( + cluster, + traitSet, + inputs.get(0), + tableSource, + tableRowType, + tableCalcProgram, + joinInfo, + joinType) + } + + //~ ExecNode methods ----------------------------------------------------------- + + override def getDamBehavior: DamBehavior = DamBehavior.PIPELINED + + override def getInputNodes: util.List[ExecNode[BatchTableEnvironment, _]] = { + List(getInput.asInstanceOf[ExecNode[BatchTableEnvironment, _]]) + } + + override def replaceInputNode( + ordinalInParent: Int, + newInputNode: ExecNode[BatchTableEnvironment, _]): Unit = { + replaceInput(ordinalInParent, newInputNode.asInstanceOf[RelNode]) + } + + override protected def translateToPlanInternal( + tableEnv: BatchTableEnvironment): StreamTransformation[BaseRow] = { + + val inputTransformation = getInputNodes.get(0).translateToPlan(tableEnv) + .asInstanceOf[StreamTransformation[BaseRow]] + val defaultParallelism = tableEnv.getConfig.getConf + .getInteger(TableConfigOptions.SQL_RESOURCE_DEFAULT_PARALLELISM) + val transformation = translateToPlanInternal( + inputTransformation, + tableEnv.streamEnv, + tableEnv.config, + tableEnv.getRelBuilder) + transformation.setParallelism(defaultParallelism) + transformation + } +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecSortAggregate.scala 
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecSortAggregate.scala index 63b07589bb6c7..208ad9ec1e574 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecSortAggregate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecSortAggregate.scala @@ -44,7 +44,7 @@ class BatchExecSortAggregate( inputRel: RelNode, outputRowType: RelDataType, inputRowType: RelDataType, - aggInputRowType: RelDataType, + val aggInputRowType: RelDataType, grouping: Array[Int], auxGrouping: Array[Int], aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)], diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecLookupJoin.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecLookupJoin.scala new file mode 100644 index 0000000000000..3c026432d4da6 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecLookupJoin.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.plan.nodes.physical.stream + +import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} +import org.apache.calcite.rel.RelNode +import org.apache.calcite.rel.`type`.RelDataType +import org.apache.calcite.rel.core.{JoinInfo, JoinRelType} +import org.apache.calcite.rex.RexProgram +import org.apache.flink.streaming.api.transformations.StreamTransformation +import org.apache.flink.table.api.StreamTableEnvironment +import org.apache.flink.table.dataformat.BaseRow +import org.apache.flink.table.plan.nodes.common.CommonLookupJoin +import org.apache.flink.table.plan.nodes.exec.{ExecNode, StreamExecNode} +import org.apache.flink.table.sources.TableSource + +import java.util + +import scala.collection.JavaConversions._ + +/** + * Stream physical RelNode for temporal table join. 
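+ *
+ * Illustrative example (table and field names are made up for documentation only): assuming
+ * `Orders` is a stream and `Rates` is backed by a
+ * [[org.apache.flink.table.sources.LookupableTableSource]] with a (unique) index on `currency`,
+ * a lookup join is expressed as
+ * {{{
+ *   SELECT o.amount * r.rate
+ *   FROM Orders AS o
+ *   JOIN Rates FOR SYSTEM_TIME AS OF o.proctime AS r
+ *   ON o.currency = r.currency
+ * }}}
+ * and this node executes it by querying `Rates` with the join key of each input row.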
+ */ +class StreamExecLookupJoin( + cluster: RelOptCluster, + traitSet: RelTraitSet, + input: RelNode, + tableSource: TableSource[_], + tableRowType: RelDataType, + tableCalcProgram: Option[RexProgram], + joinInfo: JoinInfo, + joinType: JoinRelType) + extends CommonLookupJoin( + cluster, + traitSet, + input, + tableSource, + tableRowType, + tableCalcProgram, + joinInfo, + joinType) + with StreamPhysicalRel + with StreamExecNode[BaseRow] { + + override def producesUpdates: Boolean = false + + override def needsUpdatesAsRetraction(input: RelNode): Boolean = false + + override def consumesRetractions: Boolean = false + + override def producesRetractions: Boolean = false + + override def requireWatermark: Boolean = false + + override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = { + new StreamExecLookupJoin( + cluster, + traitSet, + inputs.get(0), + tableSource, + tableRowType, + tableCalcProgram, + joinInfo, + joinType) + } + + //~ ExecNode methods ----------------------------------------------------------- + + override def getInputNodes: util.List[ExecNode[StreamTableEnvironment, _]] = { + List(getInput.asInstanceOf[ExecNode[StreamTableEnvironment, _]]) + } + + override def replaceInputNode( + ordinalInParent: Int, + newInputNode: ExecNode[StreamTableEnvironment, _]): Unit = { + replaceInput(ordinalInParent, newInputNode.asInstanceOf[RelNode]) + } + + override protected def translateToPlanInternal( + tableEnv: StreamTableEnvironment): StreamTransformation[BaseRow] = { + val inputTransformation = getInputNodes.get(0).translateToPlan(tableEnv) + .asInstanceOf[StreamTransformation[BaseRow]] + translateToPlanInternal( + inputTransformation, + tableEnv.execEnv, + tableEnv.config, + tableEnv.getRelBuilder) + } +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/optimize/program/FlinkBatchProgram.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/optimize/program/FlinkBatchProgram.scala index d91779b6881ad..5f03e5d6699ac 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/optimize/program/FlinkBatchProgram.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/optimize/program/FlinkBatchProgram.scala @@ -29,11 +29,13 @@ import org.apache.calcite.plan.hep.HepMatchOrder */ object FlinkBatchProgram { val SUBQUERY_REWRITE = "subquery_rewrite" + val CORRELATE_REWRITE = "correlate_rewrite" val DECORRELATE = "decorrelate" val DEFAULT_REWRITE = "default_rewrite" val PREDICATE_PUSHDOWN = "predicate_pushdown" val WINDOW = "window" val LOGICAL = "logical" + val LOGICAL_REWRITE = "logical_rewrite" val PHYSICAL = "physical" def buildProgram(config: Configuration): FlinkChainedProgram[BatchOptimizeContext] = { @@ -68,6 +70,24 @@ object FlinkBatchProgram { .build() ) + // rewrite special temporal join plan + chainedProgram.addLast( + CORRELATE_REWRITE, + FlinkGroupProgramBuilder.newBuilder[BatchOptimizeContext] + .addProgram( + FlinkHepRuleSetProgramBuilder.newBuilder + .setHepRulesExecutionType(HEP_RULES_EXECUTION_TYPE.RULE_SEQUENCE) + .setHepMatchOrder(HepMatchOrder.BOTTOM_UP) + .add(FlinkBatchRuleSets.EXPAND_PLAN_RULES) + .build(), "convert correlate to temporal table join") + .addProgram( + FlinkHepRuleSetProgramBuilder.newBuilder + .setHepRulesExecutionType(HEP_RULES_EXECUTION_TYPE.RULE_SEQUENCE) + .setHepMatchOrder(HepMatchOrder.BOTTOM_UP) + .add(FlinkBatchRuleSets.POST_EXPAND_CLEAN_UP_RULES) + .build(), "convert 
enumerable table scan") + .build()) + // query decorrelation chainedProgram.addLast(DECORRELATE, new FlinkDecorrelateProgram) @@ -124,6 +144,15 @@ object FlinkBatchProgram { .setRequiredOutputTraits(Array(FlinkConventions.LOGICAL)) .build()) + // logical rewrite + chainedProgram.addLast( + LOGICAL_REWRITE, + FlinkHepRuleSetProgramBuilder.newBuilder + .setHepRulesExecutionType(HEP_RULES_EXECUTION_TYPE.RULE_SEQUENCE) + .setHepMatchOrder(HepMatchOrder.BOTTOM_UP) + .add(FlinkBatchRuleSets.LOGICAL_REWRITE) + .build()) + // optimize the physical plan chainedProgram.addLast( PHYSICAL, diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/optimize/program/FlinkStreamProgram.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/optimize/program/FlinkStreamProgram.scala index 46e2a082b1d83..64cf5c903f6e6 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/optimize/program/FlinkStreamProgram.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/optimize/program/FlinkStreamProgram.scala @@ -18,18 +18,18 @@ package org.apache.flink.table.plan.optimize.program +import org.apache.calcite.plan.hep.HepMatchOrder import org.apache.flink.configuration.Configuration import org.apache.flink.table.plan.nodes.FlinkConventions import org.apache.flink.table.plan.rules.FlinkStreamRuleSets -import org.apache.calcite.plan.hep.HepMatchOrder - /** * Defines a sequence of programs to optimize for stream table plan. */ object FlinkStreamProgram { val SUBQUERY_REWRITE = "subquery_rewrite" + val CORRELATE_REWRITE = "correlate_rewrite" val DECORRELATE = "decorrelate" val TIME_INDICATOR = "time_indicator" val DEFAULT_REWRITE = "default_rewrite" @@ -70,6 +70,24 @@ object FlinkStreamProgram { .build(), "convert table references after sub-queries removed") .build()) + // rewrite special temporal join plan + chainedProgram.addLast( + CORRELATE_REWRITE, + FlinkGroupProgramBuilder.newBuilder[StreamOptimizeContext] + .addProgram( + FlinkHepRuleSetProgramBuilder.newBuilder + .setHepRulesExecutionType(HEP_RULES_EXECUTION_TYPE.RULE_SEQUENCE) + .setHepMatchOrder(HepMatchOrder.BOTTOM_UP) + .add(FlinkStreamRuleSets.EXPAND_PLAN_RULES) + .build(), "convert correlate to temporal table join") + .addProgram( + FlinkHepRuleSetProgramBuilder.newBuilder + .setHepRulesExecutionType(HEP_RULES_EXECUTION_TYPE.RULE_SEQUENCE) + .setHepMatchOrder(HepMatchOrder.BOTTOM_UP) + .add(FlinkStreamRuleSets.POST_EXPAND_CLEAN_UP_RULES) + .build(), "convert enumerable table scan") + .build()) + // query decorrelation chainedProgram.addLast(DECORRELATE, new FlinkDecorrelateProgram) diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkBatchRuleSets.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkBatchRuleSets.scala index 22f13cc8992c4..d0c7c9e91eb18 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkBatchRuleSets.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkBatchRuleSets.scala @@ -47,6 +47,17 @@ object FlinkBatchRuleSets { SubQueryRemoveRule.JOIN ) + /** + * Expand plan by replacing references to tables into a proper plan sub trees. Those rules + * can create new plan nodes. 
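+ *
+ * Simplified sketch of the rewrite performed by [[LogicalCorrelateToTemporalTableJoinRule]]
+ * (plan shapes only, not exact operator signatures):
+ * {{{
+ *   LogicalFilter(condition)                 LogicalJoin(condition)
+ *     LogicalCorrelate              ==>        LeftInput
+ *       LeftInput                              LogicalSnapshot(period)
+ *       LogicalSnapshot(period)
+ * }}}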
+ */ + val EXPAND_PLAN_RULES: RuleSet = RuleSets.ofList( + LogicalCorrelateToTemporalTableJoinRule.INSTANCE, + TableScanRule.INSTANCE) + + val POST_EXPAND_CLEAN_UP_RULES: RuleSet = RuleSets.ofList( + EnumerableToLogicalTableScan.INSTANCE) + /** * Convert table references before query decorrelation. */ @@ -112,11 +123,21 @@ object FlinkBatchRuleSets { FilterMergeRule.INSTANCE ) + /** + * Ruleset to simplify expressions + */ + private val PREDICATE_SIMPLIFY_EXPRESSION_RULES: RuleSet = RuleSets.ofList( + // TODO: add filter simply and join condition simplify rules + JoinPushExpressionsRule.INSTANCE + ) + /** * RuleSet to do predicate pushdown */ val FILTER_PREPARE_RULES: RuleSet = RuleSets.ofList(( FILTER_RULES.asScala + // simplify expressions + ++ PREDICATE_SIMPLIFY_EXPRESSION_RULES.asScala // reduce expressions in filters and joins ++ REDUCE_EXPRESSION_RULES.asScala ).asJava @@ -188,7 +209,7 @@ object FlinkBatchRuleSets { // remove aggregation if it does not aggregate and input is already distinct AggregateRemoveRule.INSTANCE, // push aggregate through join - AggregateJoinTransposeRule.EXTENDED, + FlinkAggregateJoinTransposeRule.EXTENDED, // aggregate union rule AggregateUnionAggregateRule.INSTANCE, // expand distinct aggregate to normal aggregate with groupby @@ -233,6 +254,7 @@ object FlinkBatchRuleSets { FlinkLogicalExpand.CONVERTER, FlinkLogicalRank.CONVERTER, FlinkLogicalWindowAggregate.CONVERTER, + FlinkLogicalSnapshot.CONVERTER, FlinkLogicalSink.CONVERTER ) @@ -247,6 +269,12 @@ object FlinkBatchRuleSets { LOGICAL_CONVERTERS.asScala ).asJava) + val LOGICAL_REWRITE: RuleSet = RuleSets.ofList( + // transpose calc past snapshot + CalcSnapshotTransposeRule.INSTANCE, + // merge calc after calc transpose + CalcMergeRule.INSTANCE) + /** * RuleSet to do physical optimize for batch */ @@ -271,6 +299,8 @@ object FlinkBatchRuleSets { BatchExecCorrelateRule.INSTANCE, BatchExecOverWindowAggRule.INSTANCE, BatchExecWindowAggregateRule.INSTANCE, + BatchExecLookupJoinRule.SNAPSHOT_ON_TABLESCAN, + BatchExecLookupJoinRule.SNAPSHOT_ON_CALC_TABLESCAN, BatchExecSinkRule.INSTANCE ) } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkStreamRuleSets.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkStreamRuleSets.scala index d6cdba87b9558..4b5032a0b1769 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkStreamRuleSets.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkStreamRuleSets.scala @@ -19,10 +19,9 @@ package org.apache.flink.table.plan.rules import org.apache.flink.table.plan.nodes.logical._ -import org.apache.flink.table.plan.rules.logical._ +import org.apache.flink.table.plan.rules.logical.{CalcSnapshotTransposeRule, _} import org.apache.flink.table.plan.rules.physical.FlinkExpandConversionRule import org.apache.flink.table.plan.rules.physical.stream._ - import org.apache.calcite.rel.core.RelFactories import org.apache.calcite.rel.logical.{LogicalIntersect, LogicalMinus, LogicalUnion} import org.apache.calcite.rel.rules._ @@ -47,6 +46,17 @@ object FlinkStreamRuleSets { SubQueryRemoveRule.JOIN ) + /** + * Expand plan by replacing references to tables into a proper plan sub trees. Those rules + * can create new plan nodes. 
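+ *
+ * Note: expanding table references via [[org.apache.calcite.rel.rules.TableScanRule]] can leave
+ * an EnumerableTableScan behind; the POST_EXPAND_CLEAN_UP_RULES set defined below
+ * ([[EnumerableToLogicalTableScan]]) converts such scans back into logical table scans.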
+ */ + val EXPAND_PLAN_RULES: RuleSet = RuleSets.ofList( + LogicalCorrelateToTemporalTableJoinRule.INSTANCE, + TableScanRule.INSTANCE) + + val POST_EXPAND_CLEAN_UP_RULES: RuleSet = RuleSets.ofList( + EnumerableToLogicalTableScan.INSTANCE) + /** * Convert table references before query decorrelation. */ @@ -116,11 +126,21 @@ object FlinkStreamRuleSets { FilterMergeRule.INSTANCE ) + /** + * Ruleset to simplify expressions + */ + private val PREDICATE_SIMPLIFY_EXPRESSION_RULES: RuleSet = RuleSets.ofList( + // TODO: add filter simply and join condition simplify rules + FlinkJoinPushExpressionsRule.INSTANCE + ) + /** * RuleSet to do predicate pushdown */ val FILTER_PREPARE_RULES: RuleSet = RuleSets.ofList(( FILTER_RULES.asScala + // simplify expressions + ++ PREDICATE_SIMPLIFY_EXPRESSION_RULES.asScala // reduce expressions in filters and joins ++ REDUCE_EXPRESSION_RULES.asScala ).asJava) @@ -172,7 +192,7 @@ object FlinkStreamRuleSets { SortProjectTransposeRule.INSTANCE, // join rules - JoinPushExpressionsRule.INSTANCE, + FlinkJoinPushExpressionsRule.INSTANCE, // remove union with only a single child UnionEliminatorRule.INSTANCE, @@ -222,6 +242,7 @@ object FlinkStreamRuleSets { FlinkLogicalExpand.CONVERTER, FlinkLogicalWatermarkAssigner.CONVERTER, FlinkLogicalWindowAggregate.CONVERTER, + FlinkLogicalSnapshot.CONVERTER, FlinkLogicalSink.CONVERTER ) @@ -244,6 +265,8 @@ object FlinkStreamRuleSets { FlinkLogicalRankRule.INSTANCE, // split distinct aggregate to reduce data skew SplitAggregateRule.INSTANCE, + // transpose calc past snapshot + CalcSnapshotTransposeRule.INSTANCE, // merge calc after calc transpose FlinkCalcMergeRule.INSTANCE ) @@ -271,6 +294,8 @@ object FlinkStreamRuleSets { StreamExecJoinRule.INSTANCE, StreamExecWindowJoinRule.INSTANCE, StreamExecCorrelateRule.INSTANCE, + StreamExecLookupJoinRule.SNAPSHOT_ON_TABLESCAN, + StreamExecLookupJoinRule.SNAPSHOT_ON_CALC_TABLESCAN, StreamExecSinkRule.INSTANCE ) diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/CalcSnapshotTransposeRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/CalcSnapshotTransposeRule.scala new file mode 100644 index 0000000000000..54455ce90f155 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/CalcSnapshotTransposeRule.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.table.plan.rules.logical + +import org.apache.calcite.plan.RelOptRule.{any, operand} +import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall} +import org.apache.calcite.rex.RexOver +import org.apache.flink.table.plan.nodes.logical.{FlinkLogicalCalc, FlinkLogicalSnapshot} + +/** + * Transpose [[FlinkLogicalCalc]] past into [[FlinkLogicalSnapshot]]. + */ +class CalcSnapshotTransposeRule extends RelOptRule( + operand(classOf[FlinkLogicalCalc], + operand(classOf[FlinkLogicalSnapshot], any())), + "CalcSnapshotTransposeRule") { + + override def matches(call: RelOptRuleCall): Boolean = { + val calc = call.rel[FlinkLogicalCalc](0) + // Don't push a calc which contains windowed aggregates into a snapshot for now. + !RexOver.containsOver(calc.getProgram) + } + + override def onMatch(call: RelOptRuleCall): Unit = { + val calc = call.rel[FlinkLogicalCalc](0) + val snapshot = call.rel[FlinkLogicalSnapshot](1) + val newClac = calc.copy(calc.getTraitSet, snapshot.getInputs) + val newSnapshot = snapshot.copy(snapshot.getTraitSet, newClac, snapshot.getPeriod) + call.transformTo(newSnapshot) + } +} + +object CalcSnapshotTransposeRule { + val INSTANCE = new CalcSnapshotTransposeRule +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/FlinkAggregateJoinTransposeRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/FlinkAggregateJoinTransposeRule.scala new file mode 100644 index 0000000000000..6de7a8107d56c --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/FlinkAggregateJoinTransposeRule.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.plan.rules.logical + +import org.apache.calcite.plan.RelOptRuleCall +import org.apache.calcite.plan.hep.HepRelVertex +import org.apache.calcite.plan.volcano.RelSubset +import org.apache.calcite.rel.{RelNode, SingleRel} +import org.apache.calcite.rel.core.{Aggregate, Join, RelFactories} +import org.apache.calcite.rel.logical.{LogicalAggregate, LogicalJoin, LogicalSnapshot} +import org.apache.calcite.rel.rules.AggregateJoinTransposeRule +import org.apache.calcite.tools.RelBuilderFactory + +/** + * Flink's [[AggregateJoinTransposeRule]] which does not match temporal join + * since lookup table source doesn't support aggregate. 
+ */ +class FlinkAggregateJoinTransposeRule( + aggregateClass: Class[_ <: Aggregate], + joinClass: Class[_ <: Join], + factory: RelBuilderFactory, + allowFunctions: Boolean) + extends AggregateJoinTransposeRule(aggregateClass, joinClass, factory, allowFunctions) { + + override def matches(call: RelOptRuleCall): Boolean = { + val join: Join = call.rel(1) + if (containsSnapshot(join.getRight)) { + // avoid push aggregates through temporal join + false + } else { + super.matches(call) + } + } + + private def containsSnapshot(relNode: RelNode): Boolean = { + val original = relNode match { + case r: RelSubset => r.getOriginal + case r: HepRelVertex => r.getCurrentRel + case _ => relNode + } + original match { + case _: LogicalSnapshot => true + case r: SingleRel => containsSnapshot(r.getInput) + case _ => false + } + } +} + +object FlinkAggregateJoinTransposeRule { + + /** Extended instance of the rule that can push down aggregate functions. */ + val EXTENDED = new FlinkAggregateJoinTransposeRule( + classOf[LogicalAggregate], + classOf[LogicalJoin], + RelFactories.LOGICAL_BUILDER, + true) +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/LogicalCorrelateToTemporalTableJoinRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/LogicalCorrelateToTemporalTableJoinRule.scala new file mode 100644 index 0000000000000..0f91a8a610316 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/LogicalCorrelateToTemporalTableJoinRule.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
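+ *
+ * The rule therefore bails out in matches() whenever the join's right input (possibly wrapped
+ * in a RelSubset or HepRelVertex) contains a [[LogicalSnapshot]]: pushing the aggregate below
+ * such a join would require the lookup table source to pre-aggregate its rows, which a
+ * [[org.apache.flink.table.sources.LookupableTableSource]] cannot do.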
+ */ +package org.apache.flink.table.plan.rules.logical + +import org.apache.calcite.plan.RelOptRule.{any, operand, some} +import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall} +import org.apache.calcite.rel.RelNode +import org.apache.calcite.rel.logical.{LogicalCorrelate, LogicalFilter, LogicalSnapshot} + +/** + * The initial temporal table join is a Correlate, rewrite it into a Join to make the + * join condition push-down into the Join + */ +class LogicalCorrelateToTemporalTableJoinRule + extends RelOptRule( + operand(classOf[LogicalFilter], + operand(classOf[LogicalCorrelate], some( + operand(classOf[RelNode], any()), + operand(classOf[LogicalSnapshot], any())))), + "LogicalCorrelateToTemporalTableJoinRule") { + + override def onMatch(call: RelOptRuleCall): Unit = { + val filterOnCorrelate: LogicalFilter = call.rel(0) + val correlate: LogicalCorrelate = call.rel(1) + val leftNode: RelNode = call.rel(2) + val snapshot: LogicalSnapshot = call.rel[LogicalSnapshot](3) + + val builder = call.builder() + builder.push(leftNode) + builder.push(snapshot) + builder.join( + correlate.getJoinType.toJoinType, + filterOnCorrelate.getCondition) + + call.transformTo(builder.build()) + } + +} + +object LogicalCorrelateToTemporalTableJoinRule { + val INSTANCE = new LogicalCorrelateToTemporalTableJoinRule +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/BatchExecLookupJoinRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/BatchExecLookupJoinRule.scala new file mode 100644 index 0000000000000..25db4814f089a --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/BatchExecLookupJoinRule.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.plan.rules.physical.batch + +import org.apache.calcite.plan.RelOptRule +import org.apache.calcite.rex.RexProgram +import org.apache.flink.table.calcite.FlinkTypeFactory +import org.apache.flink.table.plan.nodes.FlinkConventions +import org.apache.flink.table.plan.nodes.common.CommonLookupJoin +import org.apache.flink.table.plan.nodes.logical._ +import org.apache.flink.table.plan.nodes.physical.batch.BatchExecLookupJoin +import org.apache.flink.table.plan.rules.physical.common.{BaseSnapshotOnCalcTableScanRule, BaseSnapshotOnTableScanRule} +import org.apache.flink.table.sources.TableSource + +/** + * Rules that convert [[FlinkLogicalJoin]] on a [[FlinkLogicalSnapshot]] + * into [[BatchExecLookupJoin]]. + * + * There are 2 conditions for this rule: + * 1. 
the root parent of [[FlinkLogicalSnapshot]] should be a TableSource which implements + * [[org.apache.flink.table.sources.LookupableTableSource]]. + * 2. the period of [[FlinkLogicalSnapshot]] must be left table's proctime attribute. + */ +object BatchExecLookupJoinRule { + val SNAPSHOT_ON_TABLESCAN: RelOptRule = new SnapshotOnTableScanRule + val SNAPSHOT_ON_CALC_TABLESCAN: RelOptRule = new SnapshotOnCalcTableScanRule + + class SnapshotOnTableScanRule + extends BaseSnapshotOnTableScanRule("BatchExecSnapshotOnTableScanRule") { + + override protected def transform( + join: FlinkLogicalJoin, + input: FlinkLogicalRel, + tableSource: TableSource[_], + calcProgram: Option[RexProgram]): CommonLookupJoin = { + doTransform(join, input, tableSource, calcProgram) + } + } + + class SnapshotOnCalcTableScanRule + extends BaseSnapshotOnCalcTableScanRule("BatchExecSnapshotOnCalcTableScanRule") { + + override protected def transform( + join: FlinkLogicalJoin, + input: FlinkLogicalRel, + tableSource: TableSource[_], + calcProgram: Option[RexProgram]): CommonLookupJoin = { + doTransform(join, input, tableSource, calcProgram) + } + + } + + private def doTransform( + join: FlinkLogicalJoin, + input: FlinkLogicalRel, + tableSource: TableSource[_], + calcProgram: Option[RexProgram]): BatchExecLookupJoin = { + val joinInfo = join.analyzeCondition + val cluster = join.getCluster + val typeFactory = cluster.getTypeFactory.asInstanceOf[FlinkTypeFactory] + val tableRowType = typeFactory.buildLogicalRowType( + tableSource.getTableSchema, isStreaming = Option.apply(false)) + + val providedTrait = join.getTraitSet.replace(FlinkConventions.BATCH_PHYSICAL) + val requiredTrait = input.getTraitSet.replace(FlinkConventions.BATCH_PHYSICAL) + val convInput = RelOptRule.convert(input, requiredTrait) + new BatchExecLookupJoin( + cluster, + providedTrait, + convInput, + tableSource, + tableRowType, + calcProgram, + joinInfo, + join.getJoinType) + } +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/common/CommonLookupJoinRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/common/CommonLookupJoinRule.scala new file mode 100644 index 0000000000000..fc56cf0ec910c --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/common/CommonLookupJoinRule.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.table.plan.rules.physical.common + +import org.apache.calcite.plan.RelOptRule.{any, operand} +import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall} +import org.apache.calcite.rel.RelNode +import org.apache.calcite.rel.core.TableScan +import org.apache.calcite.rex.{RexCorrelVariable, RexFieldAccess, RexProgram} +import org.apache.flink.table.api.TableException +import org.apache.flink.table.plan.nodes.common.CommonLookupJoin +import org.apache.flink.table.plan.nodes.logical._ +import org.apache.flink.table.plan.nodes.physical.PhysicalTableSourceScan +import org.apache.flink.table.plan.schema.TimeIndicatorRelDataType +import org.apache.flink.table.sources.{LookupableTableSource, TableSource} + +/** + * Base implementation for both + * [[org.apache.flink.table.plan.rules.physical.batch.BatchExecLookupJoinRule]] and + * [[org.apache.flink.table.plan.rules.physical.stream.StreamExecLookupJoinRule]]. + */ +trait CommonLookupJoinRule { + + protected def matches( + join: FlinkLogicalJoin, + snapshot: FlinkLogicalSnapshot, + tableScan: TableScan): Boolean = { + // TODO: shouldn't match temporal UDTF join + if (findTableSource(tableScan).isEmpty) { + throw new TableException( + "Temporal table join only support join on a LookupableTableSource " + + "not on a DataStream or an intermediate query") + } + // period specification check + snapshot.getPeriod match { + // it's left table's field, pass + case r: RexFieldAccess if r.getReferenceExpr.isInstanceOf[RexCorrelVariable] => + case _ => + throw new TableException("Temporal table join currently only supports " + + "'FOR SYSTEM_TIME AS OF' left table's proctime field, doesn't support 'PROCTIME()'") + } + snapshot.getPeriod.getType match { + // TODO: support to translate rowtime temporal join to TemporalTableJoin in the future + case t: TimeIndicatorRelDataType if !t.isEventTime => // pass + case _ => + throw new TableException("Temporal table join currently only supports " + + "'FOR SYSTEM_TIME AS OF' left table's proctime field, doesn't support 'PROCTIME()'") + } + // currently temporal table join only support LookupableTableSource + isLookupableTableSource(tableScan) + } + + protected def findTableSource(relNode: RelNode): Option[TableSource[_]] = { + relNode match { + case logicalScan: FlinkLogicalTableSourceScan => Some(logicalScan.tableSource) + case physicalScan: PhysicalTableSourceScan => Some(physicalScan.tableSource) + // TODO: find TableSource in FlinkLogicalIntermediateTableScan + case _ => None + } + } + + protected def isLookupableTableSource(relNode: RelNode): Boolean = { + relNode match { + case logicalScan: FlinkLogicalTableSourceScan => + logicalScan.tableSource.isInstanceOf[LookupableTableSource[_]] + case physicalScan: PhysicalTableSourceScan => + physicalScan.tableSource.isInstanceOf[LookupableTableSource[_]] + // TODO: find TableSource in FlinkLogicalIntermediateTableScan + case _ => false + } + } + + protected def transform( + join: FlinkLogicalJoin, + input: FlinkLogicalRel, + tableSource: TableSource[_], + calcProgram: Option[RexProgram]): CommonLookupJoin +} + +abstract class BaseSnapshotOnTableScanRule(description: String) + extends RelOptRule( + operand(classOf[FlinkLogicalJoin], + operand(classOf[FlinkLogicalRel], any()), + operand(classOf[FlinkLogicalSnapshot], + operand(classOf[TableScan], any()))), + description) + with CommonLookupJoinRule { + + override def matches(call: RelOptRuleCall): Boolean = { + val join = call.rel[FlinkLogicalJoin](0) + val snapshot = 
call.rel[FlinkLogicalSnapshot](2) + val tableScan = call.rel[TableScan](3) + matches(join, snapshot, tableScan) + } + + override def onMatch(call: RelOptRuleCall): Unit = { + val join = call.rel[FlinkLogicalJoin](0) + val input = call.rel[FlinkLogicalRel](1) + val tableScan = call.rel[RelNode](3) + val tableSource = findTableSource(tableScan).orNull + + val temporalJoin = transform(join, input, tableSource, None) + call.transformTo(temporalJoin) + } + +} + +abstract class BaseSnapshotOnCalcTableScanRule(description: String) + extends RelOptRule( + operand(classOf[FlinkLogicalJoin], + operand(classOf[FlinkLogicalRel], any()), + operand(classOf[FlinkLogicalSnapshot], + operand(classOf[FlinkLogicalCalc], + operand(classOf[TableScan], any())))), + description) + with CommonLookupJoinRule { + + override def matches(call: RelOptRuleCall): Boolean = { + val join = call.rel[FlinkLogicalJoin](0) + val snapshot = call.rel[FlinkLogicalSnapshot](2) + val tableScan = call.rel[TableScan](4) + matches(join, snapshot, tableScan) + } + + override def onMatch(call: RelOptRuleCall): Unit = { + val join = call.rel[FlinkLogicalJoin](0) + val input = call.rel[FlinkLogicalRel](1) + val calc = call.rel[FlinkLogicalCalc](3) + val tableScan = call.rel[RelNode](4) + val tableSource = findTableSource(tableScan).orNull + + val temporalJoin = transform( + join, input, tableSource, Some(calc.getProgram)) + call.transformTo(temporalJoin) + } + +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/stream/StreamExecJoinRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/stream/StreamExecJoinRule.scala index 0f12ae88ebf67..cf6514ab45471 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/stream/StreamExecJoinRule.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/stream/StreamExecJoinRule.scala @@ -22,7 +22,7 @@ import org.apache.flink.table.api.TableException import org.apache.flink.table.calcite.{FlinkContext, FlinkTypeFactory} import org.apache.flink.table.plan.`trait`.FlinkRelDistribution import org.apache.flink.table.plan.nodes.FlinkConventions -import org.apache.flink.table.plan.nodes.logical.{FlinkLogicalJoin, FlinkLogicalRel} +import org.apache.flink.table.plan.nodes.logical.{FlinkLogicalJoin, FlinkLogicalRel, FlinkLogicalSnapshot} import org.apache.flink.table.plan.nodes.physical.stream.StreamExecJoin import org.apache.flink.table.plan.util.WindowJoinUtil @@ -51,11 +51,21 @@ class StreamExecJoinRule // SEMI/ANTI join always converts to StreamExecJoin now return true } + val left: FlinkLogicalRel = call.rel(1).asInstanceOf[FlinkLogicalRel] + val right: FlinkLogicalRel = call.rel(2).asInstanceOf[FlinkLogicalRel] + val tableConfig = call.getPlanner.getContext.asInstanceOf[FlinkContext].getTableConfig + val joinRowType = join.getRowType - // TODO check LHS or RHS are FlinkLogicalSnapshot + if (left.isInstanceOf[FlinkLogicalSnapshot]) { + throw new TableException( + "Temporal table join only support apply FOR SYSTEM_TIME AS OF on the right table.") + } + + // this rule shouldn't match temporal table join + if (right.isInstanceOf[FlinkLogicalSnapshot]) { + return false + } - val joinRowType = join.getRowType - val tableConfig = call.getPlanner.getContext.asInstanceOf[FlinkContext].getTableConfig val (windowBounds, remainingPreds) = WindowJoinUtil.extractWindowBoundsFromPredicate( 
join.getCondition, join.getLeft.getRowType.getFieldCount, diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/stream/StreamExecLookupJoinRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/stream/StreamExecLookupJoinRule.scala new file mode 100644 index 0000000000000..e51482a5d5196 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/stream/StreamExecLookupJoinRule.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.plan.rules.physical.stream + +import org.apache.calcite.plan.RelOptRule +import org.apache.calcite.rex.{RexNode, RexProgram} +import org.apache.flink.table.calcite.FlinkTypeFactory +import org.apache.flink.table.plan.nodes.FlinkConventions +import org.apache.flink.table.plan.nodes.common.CommonLookupJoin +import org.apache.flink.table.plan.nodes.logical._ +import org.apache.flink.table.plan.nodes.physical.stream.StreamExecLookupJoin +import org.apache.flink.table.plan.rules.physical.common.{BaseSnapshotOnCalcTableScanRule, BaseSnapshotOnTableScanRule} +import org.apache.flink.table.sources.TableSource + +/** + * Rules that convert [[FlinkLogicalJoin]] on a [[FlinkLogicalSnapshot]] + * into [[StreamExecLookupJoin]] + * + * There are 2 conditions for this rule: + * 1. the root parent of [[FlinkLogicalSnapshot]] should be a TableSource which implements + * [[org.apache.flink.table.sources.LookupableTableSource]]. + * 2. the period of [[FlinkLogicalSnapshot]] must be left table's proctime attribute. 
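+ *
+ * A hypothetical query shape that this rule rewrites (table and field names are
+ * illustrative only; the syntax mirrors the queries in the LookupJoinTest cases below):
+ * {{{
+ *   SELECT o.*, r.name
+ *   FROM Orders AS o
+ *   JOIN Rates FOR SYSTEM_TIME AS OF o.proctime AS r
+ *   ON o.currency_id = r.id
+ * }}}
+ * Here `Rates` must be backed by a [[org.apache.flink.table.sources.LookupableTableSource]]
+ * and `o.proctime` must be the left table's proctime attribute.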
+ */ +object StreamExecLookupJoinRule { + val SNAPSHOT_ON_TABLESCAN: RelOptRule = new SnapshotOnTableScanRule + val SNAPSHOT_ON_CALC_TABLESCAN: RelOptRule = new SnapshotOnCalcTableScanRule + + class SnapshotOnTableScanRule + extends BaseSnapshotOnTableScanRule("StreamExecSnapshotOnTableScanRule") { + + override protected def transform( + join: FlinkLogicalJoin, + input: FlinkLogicalRel, + tableSource: TableSource[_], + calcProgram: Option[RexProgram]): CommonLookupJoin = { + doTransform(join, input, tableSource, calcProgram) + } + } + + class SnapshotOnCalcTableScanRule + extends BaseSnapshotOnCalcTableScanRule("StreamExecSnapshotOnCalcTableScanRule") { + + override protected def transform( + join: FlinkLogicalJoin, + input: FlinkLogicalRel, + tableSource: TableSource[_], + calcProgram: Option[RexProgram]): CommonLookupJoin = { + doTransform(join, input, tableSource, calcProgram) + } + } + + private def doTransform( + join: FlinkLogicalJoin, + input: FlinkLogicalRel, + tableSource: TableSource[_], + calcProgram: Option[RexProgram]): StreamExecLookupJoin = { + + val joinInfo = join.analyzeCondition + + val cluster = join.getCluster + val typeFactory = cluster.getTypeFactory.asInstanceOf[FlinkTypeFactory] + val tableRowType = typeFactory.buildLogicalRowType( + tableSource.getTableSchema, isStreaming = Option.apply(true)) + + val providedTrait = join.getTraitSet.replace(FlinkConventions.STREAM_PHYSICAL) + var requiredTrait = input.getTraitSet.replace(FlinkConventions.STREAM_PHYSICAL) + + val convInput = RelOptRule.convert(input, requiredTrait) + new StreamExecLookupJoin( + cluster, + providedTrait, + convInput, + tableSource, + tableRowType, + calcProgram, + joinInfo, + join.getJoinType) + } +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/schema/FlinkTable.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/schema/FlinkTable.scala index 9c0c3df00fee1..42a19044c737f 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/schema/FlinkTable.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/schema/FlinkTable.scala @@ -18,14 +18,14 @@ package org.apache.flink.table.plan.schema +import org.apache.calcite.schema.TemporalTable import org.apache.flink.table.plan.stats.FlinkStatistic - import org.apache.calcite.schema.impl.AbstractTable /** * Base class for flink table. */ -abstract class FlinkTable extends AbstractTable { +abstract class FlinkTable extends AbstractTable with TemporalTable { /** * Restrict return type of statistic to FlinkStatistic. @@ -40,4 +40,13 @@ abstract class FlinkTable extends AbstractTable { */ def copy(statistic: FlinkStatistic): FlinkTable + /** + * Currently we do not need this, so we hard code it as default. + */ + override def getSysStartFieldName: String = "sys_start" + + /** + * Currently we do not need this, so we hard code it as default. 
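+ * Like the name above, this value only satisfies Calcite's
+ * [[org.apache.calcite.schema.TemporalTable]] contract and is not read by the planner yet.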
+ */ + override def getSysEndFieldName: String = "sys_end" } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/FlinkRelOptUtil.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/FlinkRelOptUtil.scala index f6a893508f9b7..d5c7487544949 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/FlinkRelOptUtil.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/FlinkRelOptUtil.scala @@ -18,9 +18,8 @@ package org.apache.flink.table.plan.util import org.apache.flink.table.api.{PlannerConfigOptions, TableConfig} -import org.apache.flink.table.calcite.{FlinkContext, FlinkPlannerImpl} +import org.apache.flink.table.calcite.{FlinkContext, FlinkPlannerImpl, FlinkTypeFactory} import org.apache.flink.table.{JBoolean, JByte, JDouble, JFloat, JLong, JShort} - import com.google.common.collect.{ImmutableList, Lists} import org.apache.calcite.config.NullCollation import org.apache.calcite.plan.RelOptUtil.InputFinder @@ -516,18 +515,22 @@ object FlinkRelOptUtil { } case OR | INPUT_REF | LITERAL | NOT => node case _ => - val bits = RelOptUtil.InputFinder.bits(node) - val mid = leftCount + extraLeftExprs.size - Side.of(bits, mid) match { - case Side.LEFT => - fix(extraRightExprs, mid, mid + 1) - extraLeftExprs.add(node) - new RexInputRef(mid, node.getType) - case Side.RIGHT => - val index2 = mid + rightCount + extraRightExprs.size - extraRightExprs.add(node) - new RexInputRef(index2, node.getType) - case _ => node + if (node.accept(new TimeIndicatorExprFinder)) { + node + } else { + val bits = RelOptUtil.InputFinder.bits(node) + val mid = leftCount + extraLeftExprs.size + Side.of(bits, mid) match { + case Side.LEFT => + fix(extraRightExprs, mid, mid + 1) + extraLeftExprs.add(node) + new RexInputRef(mid, node.getType) + case Side.RIGHT => + val index2 = mid + rightCount + extraRightExprs.size + extraRightExprs.add(node) + new RexInputRef(index2, node.getType) + case _ => node + } } } @@ -709,4 +712,13 @@ object FlinkRelOptUtil { } } + /** + * An RexVisitor to find whether this is a call on a time indicator field. + */ + class TimeIndicatorExprFinder extends RexVisitorImpl[Boolean](true) { + override def visitInputRef(inputRef: RexInputRef): Boolean = { + FlinkTypeFactory.isTimeIndicatorType(inputRef.getType) + } + } + } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/LookupJoinUtil.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/LookupJoinUtil.scala new file mode 100644 index 0000000000000..cd25888b2dee4 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/LookupJoinUtil.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.plan.util + +import org.apache.calcite.rex.RexLiteral +import org.apache.flink.table.`type`.InternalType +import org.apache.flink.table.sources.{DefinedIndexes, DefinedPrimaryKey, TableIndex, TableSource} + +import scala.collection.JavaConverters._ + +/** + * Utilities for temporal table join + */ +object LookupJoinUtil { + + /** + * A [[LookupKey]] is a field used as equal condition when querying content from dimension table + */ + sealed trait LookupKey + + /** + * A [[LookupKey]] whose value is constant. + * @param dataType the field type in TableSource + * @param literal the literal value + */ + case class ConstantLookupKey(dataType: InternalType, literal: RexLiteral) extends LookupKey + + /** + * A [[LookupKey]] whose value comes from left table field. + * @param index the index of the field in left table + */ + case class FieldRefLookupKey(index: Int) extends LookupKey + + /** + * Gets [[TableIndex]]s from a [[TableSource]]. This will combine primary key information + * of [[DefinedPrimaryKey]] and indexes information of [[DefinedIndexes]]. + */ + def getTableIndexes(table: TableSource[_]): Array[TableIndex] = { + val indexes: Array[TableIndex] = table match { + case t: DefinedIndexes if t.getIndexes != null => t.getIndexes.asScala.toArray + case _ => Array() + } + + // add primary key into index list because primary key is an index too + table match { + case t: DefinedPrimaryKey => + val primaryKey = t.getPrimaryKeyColumns + if (primaryKey != null && !primaryKey.isEmpty) { + val primaryKeyIndex = TableIndex.builder() + .uniqueIndex() + .indexedColumns(primaryKey) + .build() + indexes ++ Array(primaryKeyIndex) + } else { + indexes + } + case _ => indexes + } + } + +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/RelExplainUtil.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/RelExplainUtil.scala index 7aae9dfd3bd7a..cdf3560c796f9 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/RelExplainUtil.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/RelExplainUtil.scala @@ -588,6 +588,20 @@ object RelExplainUtil { s"Calc($name)" } + def conditionToString( + calcProgram: RexProgram, + f: (RexNode, List[String], Option[List[RexNode]]) => String): String = { + val cond = calcProgram.getCondition + val inputFieldNames = calcProgram.getInputRowType.getFieldNames.toList + val localExprs = calcProgram.getExprList.toList + + if (cond != null) { + f(cond, inputFieldNames, Some(localExprs)) + } else { + "" + } + } + def selectionToString( calcProgram: RexProgram, expression: (RexNode, List[String], Option[List[RexNode]], ExpressionFormat) => String, diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/runtime/collector/TableFunctionCollector.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/runtime/collector/TableFunctionCollector.scala deleted file mode 100644 index a8a7905b2b9c3..0000000000000 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/runtime/collector/TableFunctionCollector.scala +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.flink.table.runtime.collector - -import org.apache.flink.table.functions.TableFunction -import org.apache.flink.util.Collector - -/** - * The basic implementation of collector for [[TableFunction]]. - */ -abstract class TableFunctionCollector[T] extends Collector[T] { - - private var input: Any = _ - private var collector: Collector[_] = _ - private var collected: Boolean = _ - - /** - * Sets the input row from left table, - * which will be used to cross join with the result of table function. - */ - def setInput(input: Any): Unit = { - this.input = input - } - - /** - * Gets the input value from left table, - * which will be used to cross join with the result of table function. - */ - def getInput: Any = { - input - } - - /** - * Sets the current collector, which used to emit the final row. - */ - def setCollector(collector: Collector[_]): Unit = { - this.collector = collector - } - - /** - * Gets the internal collector which used to emit the final row. - */ - def getCollector: Collector[_] = { - this.collector - } - - /** - * Resets the flag to indicate whether [[collect(T)]] has been called. - */ - def reset(): Unit = { - collected = false - } - - /** - * Whether [[collect(T)]] has been called. - * - * @return True if [[collect(T)]] has been called. - */ - def isCollected: Boolean = collected - - override def collect(record: T): Unit = { - collected = true - } -} - - diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/api/batch/ExplainTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/api/batch/ExplainTest.xml index ef6eedd06c129..bdd2453279d27 100644 --- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/api/batch/ExplainTest.xml +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/api/batch/ExplainTest.xml @@ -20,10 +20,10 @@ limitations under the License. 
diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/api/stream/ExplainTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/api/stream/ExplainTest.xml index cf6797a0c8148..45ea29434af8d 100644 --- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/api/stream/ExplainTest.xml +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/api/stream/ExplainTest.xml @@ -23,7 +23,7 @@ LogicalProject(a=[$0], b=[$1], c=[$2]) +- LogicalTableScan(table=[[DataStreamTable]]) == Optimized Logical Plan == -DataStreamScan(table=[[_DataStreamTable_0]], fields=[a, b, c]) +DataStreamScan(table=[[_DataStreamTable_1]], fields=[a, b, c]) ]]> @@ -32,10 +32,10 @@ DataStreamScan(table=[[_DataStreamTable_0]], fields=[a, b, c]) diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/DeadlockBreakupTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/DeadlockBreakupTest.xml index cabfaec0aeab3..f035f65cbbc55 100644 --- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/DeadlockBreakupTest.xml +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/DeadlockBreakupTest.xml @@ -38,15 +38,17 @@ LogicalProject(c=[$0], a=[$1], b=[$2], c0=[$3], a0=[$4], b0=[$5]) (a, 1)]) - +- Reused(reference_id=[1]) +Calc(select=[c, a, b, c0, a1, b0]) ++- NestedLoopJoin(joinType=[InnerJoin], where=[=(a0, b0)], select=[c, a, b, a0, c0, a1, b0], build=[right]) + :- Exchange(distribution=[any], exchange_mode=[BATCH]) + : +- Calc(select=[c, a, b, CAST(a) AS a0]) + : +- HashAggregate(isMerge=[true], groupBy=[c], select=[c, Final_SUM(sum$0) AS a, Final_SUM(sum$1) AS b], reuse_id=[1]) + : +- Exchange(distribution=[hash[c]]) + : +- LocalHashAggregate(groupBy=[c], select=[c, Partial_SUM(a) AS sum$0, Partial_SUM(b) AS sum$1]) + : +- TableSourceScan(table=[[x, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) + +- Exchange(distribution=[broadcast]) + +- Calc(select=[c, a, b], where=[>(a, 1)]) + +- Reused(reference_id=[1]) ]]> diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/SubplanReuseTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/SubplanReuseTest.xml index 27400aa8c8304..302bc021ffe8c 100644 --- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/SubplanReuseTest.xml +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/SubplanReuseTest.xml @@ -239,11 +239,11 @@ LogicalProject(c=[$0], e=[$1], avg_b=[$2], sum_b=[$3], psum=[$4], nsum=[$5], avg (w0$o0, 0:BIGINT), w0$o1, null:BIGINT)), w0$o0) AS avg_b, w1$o0 AS rn, c, e], where=[AND(<>(c, _UTF-16LE'':VARCHAR(65536) CHARACTER SET "UTF-16LE"), >(-(sum_b, /(CAST(CASE(>(w0$o0, 0:BIGINT), w0$o1, null:BIGINT)), w0$o0)), 3))]) : : +- OverAggregate(partitionBy=[c, e], orderBy=[], window#0=[COUNT(sum_b) AS w0$o0, $SUM0(sum_b) AS w0$o1 RANG BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING], window#1=[RANK(*) AS w1$o0 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[c, e, sum_b, w0$o0, w0$o1, w1$o0], reuse_id=[1]) : : +- Sort(orderBy=[c ASC, e ASC], reuse_id=[2]) @@ -259,11 +259,11 @@ Calc(select=[c, e, avg_b, sum_b, sum_b0 AS psum, sum_b1 AS nsum, avg_b0 AS avg_b 
: : +- Exchange(distribution=[hash[d]]) : : +- Calc(select=[d, e], where=[>(e, 10)]) : : +- TableSourceScan(table=[[y, source: [TestTableSource(d, e, f)]]], fields=[d, e, f]) - : +- Exchange(distribution=[hash[c, e]], exchange_mode=[BATCH]) - : +- Calc(select=[sum_b, /(CAST(CASE(>(w0$o0, 0:BIGINT), w0$o1, null:BIGINT)), w0$o0) AS avg_b, w1$o0 AS rn, c, e]) + : +- Exchange(distribution=[hash[c, e, $f5]], exchange_mode=[BATCH]) + : +- Calc(select=[sum_b, /(CAST(CASE(>(w0$o0, 0:BIGINT), w0$o1, null:BIGINT)), w0$o0) AS avg_b, c, e, +(w1$o0, 1) AS $f5]) : +- Reused(reference_id=[1]) - +- Exchange(distribution=[hash[c, e]], exchange_mode=[BATCH]) - +- Calc(select=[sum_b, w0$o0 AS rn, c, e]) + +- Exchange(distribution=[hash[c, e, $f5]], exchange_mode=[BATCH]) + +- Calc(select=[sum_b, c, e, -(w0$o0, 1) AS $f5]) +- OverAggregate(partitionBy=[c, e], orderBy=[c ASC, e ASC], window#0=[RANK(*) AS w0$o0 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[c, e, sum_b, w0$o0]) +- Reused(reference_id=[2]) ]]> @@ -291,15 +291,17 @@ LogicalProject(c=[$0], a=[$1], b=[$2], c0=[$3], a0=[$4], b0=[$5]) (a, 1)]) - +- Reused(reference_id=[1]) +Calc(select=[c, a, b, c0, a1, b0]) ++- HashJoin(joinType=[InnerJoin], where=[=(a0, b0)], select=[c, a, b, a0, c0, a1, b0], build=[right]) + :- Exchange(distribution=[hash[a0]], exchange_mode=[BATCH]) + : +- Calc(select=[c, a, b, CAST(a) AS a0]) + : +- SortAggregate(isMerge=[false], groupBy=[c], select=[c, MyFirst(a) AS a, MyLast(b) AS b], reuse_id=[1]) + : +- Sort(orderBy=[c ASC]) + : +- Exchange(distribution=[hash[c]]) + : +- TableSourceScan(table=[[x, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) + +- Exchange(distribution=[hash[b]]) + +- Calc(select=[c, a, b], where=[>(a, 1)]) + +- Reused(reference_id=[1]) ]]> @@ -465,6 +467,71 @@ NestedLoopJoin(joinType=[InnerJoin], where=[=(c, f00)], select=[a, b, c, f0, a0, : +- TableSourceScan(table=[[x, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) +- Correlate(invocation=[TableFun($cor1.c)], correlate=[table(TableFun($cor1.c))], select=[a,b,c,f0], rowType=[RecordType(INTEGER a, BIGINT b, VARCHAR(65536) c, VARCHAR(65536) f0)], joinType=[INNER]) +- TableSourceScan(table=[[x, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) +]]> + + + + + + + + + + + + + + + + 10) +UNION ALL +(SELECT a FROM t WHERE b > 10) + ]]> + + + ($0, 10)]) +: +- LogicalTableScan(table=[[t]]) ++- LogicalProject(a=[$0]) + +- LogicalFilter(condition=[>($1, 10)]) + +- LogicalTableScan(table=[[t]]) +]]> + + + (a, 10)]) +: +- BoundedStreamScan(table=[[t]], fields=[a, b, c], reuse_id=[1]) ++- Calc(select=[a], where=[>(b, 10)]) + +- Reused(reference_id=[1]) ]]> @@ -536,15 +603,17 @@ LogicalProject(c=[$0], a=[$1], b=[$2], c0=[$3], a0=[$4], b0=[$5]) (a, 1)]) - +- Reused(reference_id=[1]) +Calc(select=[c, a, b, c0, a1, b0]) ++- HashJoin(joinType=[InnerJoin], where=[=(a0, b0)], select=[c, a, b, a0, c0, a1, b0], build=[right]) + :- Exchange(distribution=[hash[a0]], exchange_mode=[BATCH]) + : +- Calc(select=[c, a, b, CAST(a) AS a0]) + : +- HashAggregate(isMerge=[true], groupBy=[c], select=[c, Final_SUM(sum$0) AS a, Final_SUM(sum$1) AS b], reuse_id=[1]) + : +- Exchange(distribution=[hash[c]]) + : +- LocalHashAggregate(groupBy=[c], select=[c, Partial_SUM(a) AS sum$0, Partial_SUM(b) AS sum$1]) + : +- TableSourceScan(table=[[x, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) + +- Exchange(distribution=[hash[b]]) + +- Calc(select=[c, a, b], where=[>(a, 1)]) + +- Reused(reference_id=[1]) ]]> @@ -614,16 +683,18 @@ LogicalProject(a=[$0], 
b=[$1], c=[$2], d=[$3], e=[$4], f=[$5], a0=[$6], b0=[$7], (b, 1)]) -: : +- TableSourceScan(table=[[x, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) -: +- Calc(select=[d, e, f], where=[<(e, 2)]) -: +- TableSourceScan(table=[[y, source: [TestTableSource(d, e, f)]]], fields=[d, e, f]) -+- Exchange(distribution=[any], exchange_mode=[BATCH]) - +- Reused(reference_id=[1]) +Calc(select=[a, b, c, d, e, f, a1, b0, c0, d0, e0, f0]) ++- HashJoin(joinType=[InnerJoin], where=[=(a0, b0)], select=[a, b, c, d, e, f, a0, a1, b0, c0, d0, e0, f0], build=[right]) + :- Exchange(distribution=[hash[a0]], exchange_mode=[BATCH]) + : +- Calc(select=[a, b, c, d, e, f, CAST(a) AS a0]) + : +- NestedLoopJoin(joinType=[InnerJoin], where=[OR(=(ABS(a), ABS(d)), =(c, f))], select=[a, b, c, d, e, f], build=[left], reuse_id=[1]) + : :- Exchange(distribution=[broadcast]) + : : +- Calc(select=[a, b, c], where=[>(b, 1)]) + : : +- TableSourceScan(table=[[x, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) + : +- Calc(select=[d, e, f], where=[<(e, 2)]) + : +- TableSourceScan(table=[[y, source: [TestTableSource(d, e, f)]]], fields=[d, e, f]) + +- Exchange(distribution=[hash[b]]) + +- Reused(reference_id=[1]) ]]> @@ -654,16 +725,18 @@ LogicalProject(a=[$0], b=[$1], c=[$2], d=[$3], e=[$4], f=[$5], a0=[$6], b0=[$7], (b, 1)]) -: : +- TableSourceScan(table=[[x, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) -: +- Calc(select=[d, e, f], where=[<(e, 2)]) -: +- TableSourceScan(table=[[y, source: [TestTableSource(d, e, f)]]], fields=[d, e, f]) -+- Exchange(distribution=[any], exchange_mode=[BATCH]) - +- Reused(reference_id=[1]) +Calc(select=[a, b, c, d, e, f, a1, b0, c0, d0, e0, f0]) ++- HashJoin(joinType=[InnerJoin], where=[=(a0, b0)], select=[a, b, c, d, e, f, a0, a1, b0, c0, d0, e0, f0], build=[right]) + :- Exchange(distribution=[hash[a0]], exchange_mode=[BATCH]) + : +- Calc(select=[a, b, c, d, e, f, CAST(a) AS a0]) + : +- NestedLoopJoin(joinType=[InnerJoin], where=[OR(=(random_udf(a), random_udf(d)), =(c, f))], select=[a, b, c, d, e, f], build=[left], reuse_id=[1]) + : :- Exchange(distribution=[broadcast]) + : : +- Calc(select=[a, b, c], where=[>(b, 1)]) + : : +- TableSourceScan(table=[[x, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) + : +- Calc(select=[d, e, f], where=[<(e, 2)]) + : +- TableSourceScan(table=[[y, source: [TestTableSource(d, e, f)]]], fields=[d, e, f]) + +- Exchange(distribution=[hash[b]]) + +- Reused(reference_id=[1]) ]]> @@ -689,16 +762,17 @@ LogicalProject(a=[$0], b=[$1], a0=[$2]) @@ -737,39 +811,6 @@ NestedLoopJoin(joinType=[InnerJoin], where=[=(a, d0)], select=[a, b, c, d, e, f, : +- TableSourceScan(table=[[y, source: [TestTableSource(d, e, f)]]], fields=[d, e, f]) +- Exchange(distribution=[any], exchange_mode=[BATCH]) +- Reused(reference_id=[1]) -]]> - - - - - 10 - ]]> - - - ($4, 10))]) - +- LogicalJoin(condition=[true], joinType=[inner]) - :- LogicalProject(a=[$0], b=[$1], EXPR$2=[RANK() OVER (ORDER BY $2 DESC NULLS LAST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)]) - : +- LogicalTableScan(table=[[x, source: [TestTableSource(a, b, c)]]]) - +- LogicalProject(a=[$0], b=[$1], EXPR$2=[RANK() OVER (ORDER BY $2 DESC NULLS LAST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)]) - +- LogicalTableScan(table=[[x, source: [TestTableSource(a, b, c)]]]) -]]> - - - (b, 10)]) - +- Reused(reference_id=[1]) ]]> @@ -841,73 +882,73 @@ HashJoin(joinType=[InnerJoin], where=[=(a, a0)], select=[c, a, b, c0, a0, b0], b ]]> - + 1 ]]> ($4, 1))]) +- LogicalJoin(condition=[true], 
joinType=[inner]) - :- LogicalSort(fetch=[10]) - : +- LogicalProject(a=[$0], b=[$1]) + :- LogicalAggregate(group=[{0}], a=[SUM($1)], b=[SUM($2)]) + : +- LogicalProject(c=[$2], a=[$0], b=[$1]) : +- LogicalTableScan(table=[[x, source: [TestTableSource(a, b, c)]]]) - +- LogicalSort(fetch=[10]) - +- LogicalProject(a=[$0], b=[$1]) + +- LogicalAggregate(group=[{0}], a=[SUM($1)], b=[SUM($2)]) + +- LogicalProject(c=[$2], a=[$0], b=[$1]) +- LogicalTableScan(table=[[x, source: [TestTableSource(a, b, c)]]]) ]]> (a, 1)]) + +- Reused(reference_id=[1]) ]]> - + 1 +WITH r AS (SELECT a, b, RANK() OVER (ORDER BY c DESC) FROM x) +SELECT * FROM r r1, r r2 WHERE r1.a = r2.a AND r1.b < 100 AND r2.b > 10 ]]> ($4, 1))]) +LogicalProject(a=[$0], b=[$1], EXPR$2=[$2], a0=[$3], b0=[$4], EXPR$20=[$5]) ++- LogicalFilter(condition=[AND(=($0, $3), <($1, 100), >($4, 10))]) +- LogicalJoin(condition=[true], joinType=[inner]) - :- LogicalAggregate(group=[{0}], a=[SUM($1)], b=[SUM($2)]) - : +- LogicalProject(c=[$2], a=[$0], b=[$1]) - : +- LogicalTableScan(table=[[x, source: [TestTableSource(a, b, c)]]]) - +- LogicalAggregate(group=[{0}], a=[SUM($1)], b=[SUM($2)]) - +- LogicalProject(c=[$2], a=[$0], b=[$1]) - +- LogicalTableScan(table=[[x, source: [TestTableSource(a, b, c)]]]) + :- LogicalProject(a=[$0], b=[$1], EXPR$2=[RANK() OVER (ORDER BY $2 DESC NULLS LAST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)]) + : +- LogicalTableScan(table=[[x, source: [TestTableSource(a, b, c)]]]) + +- LogicalProject(a=[$0], b=[$1], EXPR$2=[RANK() OVER (ORDER BY $2 DESC NULLS LAST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)]) + +- LogicalTableScan(table=[[x, source: [TestTableSource(a, b, c)]]]) ]]> (a, 1)]) +HashJoin(joinType=[InnerJoin], where=[=(a, a0)], select=[a, b, $2, a0, b0, $20], build=[right]) +:- Exchange(distribution=[hash[a]], exchange_mode=[BATCH]) +: +- Calc(select=[a, b, w0$o0 AS $2], where=[<(b, 100)]) +: +- OverAggregate(orderBy=[c DESC], window#0=[RANK(*) AS w0$o0 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, b, c, w0$o0], reuse_id=[1]) +: +- Sort(orderBy=[c DESC]) +: +- Exchange(distribution=[single]) +: +- TableSourceScan(table=[[x, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) ++- Exchange(distribution=[hash[a]]) + +- Calc(select=[a, b, w0$o0 AS $2], where=[>(b, 10)]) +- Reused(reference_id=[1]) ]]> diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/BroadcastHashJoinTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/BroadcastHashJoinTest.xml index 2d0e4d9b8f5d2..b61773f230011 100644 --- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/BroadcastHashJoinTest.xml +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/BroadcastHashJoinTest.xml @@ -157,6 +157,31 @@ Calc(select=[c, g]) : +- TableSourceScan(table=[[MyTable2, source: [TestTableSource(d, e, f, g, h)]]], fields=[d, e, f, g, h]) +- Exchange(distribution=[broadcast]) +- TableSourceScan(table=[[MyTable1, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) +]]> + + + + + + + + + + + diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/LookupJoinTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/LookupJoinTest.xml new file mode 100644 index 0000000000000..56faaac15412e --- /dev/null +++ 
b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/LookupJoinTest.xml @@ -0,0 +1,335 @@ + + + + + + 10 + ) AS T +GROUP BY b + ]]> + + + ($7, 10)]) + +- LogicalFilter(condition=[=($1, $5)]) + +- LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{4}]) + :- LogicalProject(b=[$1], a=[$0], c=[$2], d=[$3], proctime=[PROCTIME()]) + : +- LogicalAggregate(group=[{0, 1}], c=[SUM($2)], d=[SUM($3)]) + : +- LogicalTableScan(table=[[T1]]) + +- LogicalSnapshot(period=[$cor0.proctime]) + +- LogicalTableScan(table=[[temporalTest, source: [TestTemporalTable(id, name, age)]]]) +]]> + + + (age, 10)], select=[b, a, c, d, id]) + +- Calc(select=[b, a, c, d]) + +- HashAggregate(isMerge=[true], groupBy=[a, b], select=[a, b, Final_SUM(sum$0) AS c, Final_SUM(sum$1) AS d]) + +- Exchange(distribution=[hash[a, b]]) + +- LocalHashAggregate(groupBy=[a, b], select=[a, b, Partial_SUM(c) AS sum$0, Partial_SUM(d) AS sum$1]) + +- BoundedStreamScan(table=[[T1]], fields=[a, b, c, d]) +]]> + + + + + + + + + + + + + + + + 1000 + ]]> + + + ($2, 1000)]) + +- LogicalFilter(condition=[AND(=($0, $4), =($6, 10))]) + +- LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{3}]) + :- LogicalTableScan(table=[[MyTable]]) + +- LogicalSnapshot(period=[$cor0.proctime]) + +- LogicalTableScan(table=[[temporalTest, source: [TestTemporalTable(id, name, age)]]]) +]]> + + + (c, 1000)]) + +- BoundedStreamScan(table=[[T0]], fields=[a, b, c]) +]]> + + + + + 1000) AS T JOIN temporalTest FOR SYSTEM_TIME AS OF T.proctime AS D ON T.a = D.id]]> + + + ($2, 1000)]) + : +- LogicalTableScan(table=[[MyTable]]) + +- LogicalSnapshot(period=[$cor0.proctime]) + +- LogicalTableScan(table=[[temporalTest, source: [TestTemporalTable(id, name, age)]]]) +]]> + + + (c, 1000)]) + +- BoundedStreamScan(table=[[T0]], fields=[a, b, c]) +]]> + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 10 + ) AS T +GROUP BY b + ]]> + + + ($7, 10)]) + +- LogicalFilter(condition=[=($1, $5)]) + +- LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{4}]) + :- LogicalProject(b=[$1], a=[$0], c=[$2], d=[$3], proctime=[PROCTIME()]) + : +- LogicalAggregate(group=[{0, 1}], c=[SUM($2)], d=[SUM($3)]) + : +- LogicalTableScan(table=[[T1]]) + +- LogicalSnapshot(period=[$cor0.proctime]) + +- LogicalTableScan(table=[[temporalTest, source: [TestTemporalTable(id, name, age)]]]) +]]> + + + (age, 10)]) + +- FlinkLogicalTableSourceScan(table=[[temporalTest, source: [TestTemporalTable(id, name, age)]]], fields=[id, name, age]) +]]> + + + + + 10 + ) AS T1, ( +SELECT id as a, b FROM ( +SELECT * FROM ( +SELECT b, a, sum(c) c, sum(d) d, PROCTIME() as proctime +FROM T1 +GROUP BY a, b + ) AS T +JOIN temporalTest FOR SYSTEM_TIME AS OF T.proctime AS D +ON T.a = D.id +WHERE D.age > 10 + ) AS T + ) AS T2 +WHERE T1.a = T2.a +GROUP BY T1.b, T2.b + ]]> + + + ($7, 10)]) + : +- LogicalFilter(condition=[=($1, $5)]) + : +- LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{4}]) + : :- LogicalProject(b=[$1], a=[$0], c=[$2], d=[$3], proctime=[PROCTIME()]) + : : +- LogicalAggregate(group=[{0, 1}], c=[SUM($2)], d=[SUM($3)]) + : : +- LogicalTableScan(table=[[T1]]) + : +- LogicalSnapshot(period=[$cor0.proctime]) + : +- LogicalTableScan(table=[[temporalTest, source: [TestTemporalTable(id, name, age)]]]) + +- LogicalProject(a=[$5], b=[$0]) + +- LogicalFilter(condition=[>($7, 10)]) + +- LogicalFilter(condition=[=($1, $5)]) + +- LogicalCorrelate(correlation=[$cor1], 
joinType=[inner], requiredColumns=[{4}]) + :- LogicalProject(b=[$1], a=[$0], c=[$2], d=[$3], proctime=[PROCTIME()]) + : +- LogicalAggregate(group=[{0, 1}], c=[SUM($2)], d=[SUM($3)]) + : +- LogicalTableScan(table=[[T1]]) + +- LogicalSnapshot(period=[$cor1.proctime]) + +- LogicalTableScan(table=[[temporalTest, source: [TestTemporalTable(id, name, age)]]]) +]]> + + + (age, 10)], select=[b, a, id], reuse_id=[1]) + : +- Calc(select=[b, a]) + : +- HashAggregate(isMerge=[true], groupBy=[a, b], select=[a, b, Final_SUM(sum$0) AS c, Final_SUM(sum$1) AS d]) + : +- Exchange(distribution=[hash[a, b]]) + : +- LocalHashAggregate(groupBy=[a, b], select=[a, b, Partial_SUM(c) AS sum$0, Partial_SUM(d) AS sum$1]) + : +- BoundedStreamScan(table=[[T1]], fields=[a, b, c, d]) + +- Exchange(distribution=[hash[a]]) + +- Calc(select=[id AS a, b]) + +- Reused(reference_id=[1]) +]]> + + + diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/NestedLoopJoinTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/NestedLoopJoinTest.xml index cdfb68b6378f8..a0d0dc393371f 100644 --- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/NestedLoopJoinTest.xml +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/NestedLoopJoinTest.xml @@ -338,12 +338,13 @@ LogicalProject(a=[$0], d=[$3]) diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/ShuffledHashJoinTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/ShuffledHashJoinTest.xml index 8a2dff73d3362..4be436f90b7a9 100644 --- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/ShuffledHashJoinTest.xml +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/ShuffledHashJoinTest.xml @@ -239,6 +239,32 @@ Calc(select=[c, g]) : +- TableSourceScan(table=[[MyTable2, source: [TestTableSource(d, e, f, g, h)]]], fields=[d, e, f, g, h]) +- Exchange(distribution=[hash[a, b]]) +- TableSourceScan(table=[[MyTable1, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) +]]> + + + + + + + + + + + diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/SortMergeJoinTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/SortMergeJoinTest.xml index 6118dd209298b..d21ac60d3ee28 100644 --- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/SortMergeJoinTest.xml +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/SortMergeJoinTest.xml @@ -239,6 +239,32 @@ Calc(select=[c, g]) : +- TableSourceScan(table=[[MyTable2, source: [TestTableSource(d, e, f, g, h)]]], fields=[d, e, f, g, h]) +- Exchange(distribution=[hash[a, b]]) +- TableSourceScan(table=[[MyTable1, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) +]]> + + + + + + + + + + + diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/stream/sql/SubplanReuseTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/stream/sql/SubplanReuseTest.xml index bdb347de68a61..07639da0ebab8 100644 --- 
a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/stream/sql/SubplanReuseTest.xml +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/stream/sql/SubplanReuseTest.xml @@ -154,14 +154,16 @@ LogicalProject(c=[$0], a=[$1], b=[$2], c0=[$3], a0=[$4], b0=[$5]) (a, 1)]) - +- Reused(reference_id=[1]) +Calc(select=[c, a, b, c0, a1, b0]) ++- Join(joinType=[InnerJoin], where=[=(a0, b0)], select=[c, a, b, a0, c0, a1, b0]) + :- Exchange(distribution=[hash[a0]]) + : +- Calc(select=[c, a, b, CAST(a) AS a0]) + : +- GroupAggregate(groupBy=[c], select=[c, MyFirst(a) AS a, MyLast(b) AS b], reuse_id=[1]) + : +- Exchange(distribution=[hash[c]]) + : +- TableSourceScan(table=[[x, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) + +- Exchange(distribution=[hash[b]]) + +- Calc(select=[c, a, b], where=[>(a, 1)]) + +- Reused(reference_id=[1]) ]]> @@ -372,14 +374,16 @@ LogicalProject(c=[$0], a=[$1], b=[$2], c0=[$3], a0=[$4], b0=[$5]) (a, 1)]) - +- Reused(reference_id=[1]) +Calc(select=[c, a, b, c0, a1, b0]) ++- Join(joinType=[InnerJoin], where=[=(a0, b0)], select=[c, a, b, a0, c0, a1, b0]) + :- Exchange(distribution=[hash[a0]]) + : +- Calc(select=[c, a, b, CAST(a) AS a0]) + : +- GroupAggregate(groupBy=[c], select=[c, SUM(a) AS a, SUM(b) AS b], reuse_id=[1]) + : +- Exchange(distribution=[hash[c]]) + : +- TableSourceScan(table=[[x, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) + +- Exchange(distribution=[hash[b]]) + +- Calc(select=[c, a, b], where=[>(a, 1)]) + +- Reused(reference_id=[1]) ]]> diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/stream/sql/join/LookupJoinTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/stream/sql/join/LookupJoinTest.xml new file mode 100644 index 0000000000000..c4308d7bbfc14 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/stream/sql/join/LookupJoinTest.xml @@ -0,0 +1,267 @@ + + + + + + 1000 + ]]> + + + (CAST($6):BIGINT, 1000)]) + +- LogicalFilter(condition=[AND(=($0, $5), =($7, 10))]) + +- LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{3}]) + :- LogicalTableScan(table=[[MyTable]]) + +- LogicalSnapshot(period=[$cor0.proctime]) + +- LogicalTableScan(table=[[temporalTest, source: [TestTemporalTable(id, name, age)]]]) +]]> + + + (CAST(name), 1000))], select=[a, b, c, proctime, rowtime, id, name, age]) + +- DataStreamScan(table=[[_DataStreamTable_0]], fields=[a, b, c, proctime, rowtime]) +]]> + + + + + 10 + ) AS T +GROUP BY b + ]]> + + + ($7, 10)]) + +- LogicalFilter(condition=[=($1, $5)]) + +- LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{4}]) + :- LogicalProject(b=[$1], a=[$0], c=[$2], d=[$3], proc=[PROCTIME()]) + : +- LogicalAggregate(group=[{0, 1}], c=[SUM($2)], d=[SUM($3)]) + : +- LogicalTableScan(table=[[T1]]) + +- LogicalSnapshot(period=[$cor0.proc]) + +- LogicalTableScan(table=[[temporalTest, source: [TestTemporalTable(id, name, age)]]]) +]]> + + + (age, 10)], select=[b, a, c, d, id]) + +- Calc(select=[b, a, c, d]) + +- GroupAggregate(groupBy=[a, b], select=[a, b, SUM(c) AS c, SUM(d) AS d]) + +- Exchange(distribution=[hash[a, b]]) + +- DataStreamScan(table=[[_DataStreamTable_1]], fields=[a, b, c, d]) +]]> + + + + + + + + + + + + + + + + + + + + + + + + + + + 1000 + ]]> + + + ($2, 1000)]) + +- LogicalFilter(condition=[AND(=($0, $5), =($7, 10))]) + +- LogicalCorrelate(correlation=[$cor0], 
joinType=[inner], requiredColumns=[{3}]) + :- LogicalTableScan(table=[[MyTable]]) + +- LogicalSnapshot(period=[$cor0.proctime]) + +- LogicalTableScan(table=[[temporalTest, source: [TestTemporalTable(id, name, age)]]]) +]]> + + + (c, 1000)]) + +- DataStreamScan(table=[[_DataStreamTable_0]], fields=[a, b, c, proctime, rowtime]) +]]> + + + + + 1000 + ]]> + + + ($2, 1000)]) + +- LogicalFilter(condition=[AND(=($0, $5), =($7, 10), =($6, _UTF-16LE'AAA'))]) + +- LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{3}]) + :- LogicalTableScan(table=[[MyTable]]) + +- LogicalSnapshot(period=[$cor0.proctime]) + +- LogicalTableScan(table=[[temporalTest, source: [TestTemporalTable(id, name, age)]]]) +]]> + + + (c, 1000)]) + +- DataStreamScan(table=[[_DataStreamTable_0]], fields=[a, b, c, proctime, rowtime]) +]]> + + + + + 1000) AS T JOIN temporalTest FOR SYSTEM_TIME AS OF T.proctime AS D ON T.a = D.id]]> + + + ($2, 1000)]) + : +- LogicalTableScan(table=[[MyTable]]) + +- LogicalSnapshot(period=[$cor0.proctime]) + +- LogicalTableScan(table=[[temporalTest, source: [TestTemporalTable(id, name, age)]]]) +]]> + + + (c, 1000)]) + +- DataStreamScan(table=[[_DataStreamTable_0]], fields=[a, b, c, proctime, rowtime]) +]]> + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/api/batch/ExplainTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/api/batch/ExplainTest.scala index 2e0c8eeb4be7b..abb87e1bbd394 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/api/batch/ExplainTest.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/api/batch/ExplainTest.scala @@ -26,7 +26,7 @@ import org.junit.Test class ExplainTest extends TableTestBase { private val util = batchTestUtil() - util.addTableSource[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + util.addDataStream[(Int, Long, String)]("MyTable", 'a, 'b, 'c) @Test def testExplainTableSourceScan(): Unit = { diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/api/stream/ExplainTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/api/stream/ExplainTest.scala index ff3807833ebae..96f7a3ad3c0ba 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/api/stream/ExplainTest.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/api/stream/ExplainTest.scala @@ -26,7 +26,7 @@ import org.junit.Test class ExplainTest extends TableTestBase { private val util = streamTestUtil() - util.addTableSource[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + util.addDataStream[(Int, Long, String)]("MyTable", 'a, 'b, 'c) @Test def testExplainTableSourceScan(): Unit = { diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/SubplanReuseTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/SubplanReuseTest.scala index 2bf4d43493d8a..f85d8de49dd3f 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/SubplanReuseTest.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/SubplanReuseTest.scala @@ -110,8 +110,7 @@ class SubplanReuseTest extends TableTestBase { util.verifyPlan(sqlQuery) } - - @Test(expected = classOf[TableException]) + @Test def 
testSubplanReuseOnDataStreamTable(): Unit = { util.addDataStream[(Int, Long, String)]("t", 'a, 'b, 'c) val sqlQuery = diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/join/BroadcastHashJoinTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/join/BroadcastHashJoinTest.scala index 213b44596745c..0bde6c9eb25bf 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/join/BroadcastHashJoinTest.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/join/BroadcastHashJoinTest.scala @@ -47,13 +47,6 @@ class BroadcastHashJoinTest extends JoinTestBase { super.testInnerJoinWithoutJoinPred() } - @Test - override def testInnerJoinWithNonEquiPred(): Unit = { - thrown.expect(classOf[TableException]) - thrown.expectMessage("Cannot generate a valid execution plan for the given query") - super.testInnerJoinWithNonEquiPred() - } - @Test override def testLeftOuterJoinNoEquiPred(): Unit = { thrown.expect(classOf[TableException]) diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/join/LookupJoinTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/join/LookupJoinTest.scala new file mode 100644 index 0000000000000..94f015e80783a --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/join/LookupJoinTest.scala @@ -0,0 +1,269 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.table.plan.batch.sql.join + +import org.apache.flink.api.scala._ +import org.apache.flink.table.api._ +import org.apache.flink.table.calcite.CalciteConfig +import org.apache.flink.table.plan.optimize.program.FlinkBatchProgram +import org.apache.flink.table.plan.stream.sql.join.TestTemporalTable +import org.apache.flink.table.util.TableTestBase +import org.junit.Assert.{assertTrue, fail} +import org.junit.{Before, Ignore, Test} + +class LookupJoinTest extends TableTestBase { + private val testUtil = batchTestUtil() + + @Before + def before(): Unit = { + testUtil.addDataStream[(Int, String, Long)]("T0", 'a, 'b, 'c) + testUtil.addDataStream[(Int, String, Long, Double)]("T1", 'a, 'b, 'c, 'd) + testUtil.addDataStream[(Int, String, Int)]("nonTemporal", 'id, 'name, 'age) + testUtil.tableEnv.registerTableSource("temporalTest", new TestTemporalTable) + val myTable = testUtil.tableEnv.sqlQuery("SELECT *, PROCTIME() as proctime FROM T0") + testUtil.tableEnv.registerTable("MyTable", myTable) + } + + @Test + def testJoinInvalidJoinTemporalTable(): Unit = { + // must follow a period specification + expectExceptionThrown( + "SELECT * FROM MyTable AS T JOIN temporalTest T.proc AS D ON T.a = D.id", + "SQL parse failed", + classOf[SqlParserException]) + + // can't query a dim table directly + expectExceptionThrown( + "SELECT * FROM temporalTest FOR SYSTEM_TIME AS OF TIMESTAMP '2017-08-09 14:36:11'", + "Cannot generate a valid execution plan for the given query", + classOf[TableException] + ) + + // can't on non-key fields + expectExceptionThrown( + "SELECT * FROM MyTable AS T JOIN temporalTest " + + "FOR SYSTEM_TIME AS OF T.proctime AS D ON T.a = D.age", + "Temporal table join requires an equality condition on ALL fields of table " + + "[TestTemporalTable(id, name, age)]'s PRIMARY KEY or (UNIQUE) INDEX(s).", + classOf[TableException] + ) + + // only support left or inner join + expectExceptionThrown( + "SELECT * FROM MyTable AS T RIGHT JOIN temporalTest " + + "FOR SYSTEM_TIME AS OF T.proctime AS D ON T.a = D.id", + "Unsupported join type for semi-join RIGHT", + classOf[IllegalArgumentException] + ) + + // only support join on raw key of right table + expectExceptionThrown( + "SELECT * FROM MyTable AS T LEFT JOIN temporalTest " + + "FOR SYSTEM_TIME AS OF T.proctime AS D ON T.a + 1 = D.id + 1", + "Temporal table join requires an equality condition on ALL fields of table " + + "[TestTemporalTable(id, name, age)]'s PRIMARY KEY or (UNIQUE) INDEX(s).", + classOf[TableException] + ) + } + + + @Test + def testLogicalPlan(): Unit = { + val sql1 = + """ + |SELECT b, a, sum(c) c, sum(d) d, PROCTIME() as proctime + |FROM T1 + |GROUP BY a, b + """.stripMargin + + val sql2 = + s""" + |SELECT T.* FROM ($sql1) AS T + |JOIN temporalTest FOR SYSTEM_TIME AS OF T.proctime AS D + |ON T.a = D.id + |WHERE D.age > 10 + """.stripMargin + + val sql = + s""" + |SELECT b, count(a), sum(c), sum(d) + |FROM ($sql2) AS T + |GROUP BY b + """.stripMargin + val programs = FlinkBatchProgram.buildProgram(testUtil.tableEnv.getConfig.getConf) + programs.remove(FlinkBatchProgram.PHYSICAL) + val calciteConfig = CalciteConfig.createBuilder(testUtil.tableEnv.getConfig.getCalciteConfig) + .replaceBatchProgram(programs).build() + testUtil.tableEnv.getConfig.setCalciteConfig(calciteConfig) + testUtil.verifyPlan(sql) + } + + @Test + def testLogicalPlanWithImplicitTypeCast(): Unit = { + val programs = FlinkBatchProgram.buildProgram(testUtil.tableEnv.getConfig.getConf) + programs.remove(FlinkBatchProgram.PHYSICAL) + val 
calciteConfig = CalciteConfig.createBuilder(testUtil.tableEnv.getConfig.getCalciteConfig) + .replaceBatchProgram(programs).build() + testUtil.tableEnv.getConfig.setCalciteConfig(calciteConfig) + + testUtil.verifyPlan("SELECT * FROM MyTable AS T JOIN temporalTest " + + "FOR SYSTEM_TIME AS OF T.proctime AS D ON T.b = D.id") + } + + @Test + def testJoinInvalidNonTemporalTable(): Unit = { + // can't follow a period specification + expectExceptionThrown( + "SELECT * FROM MyTable AS T JOIN nonTemporal " + + "FOR SYSTEM_TIME AS OF T.proctime AS D ON T.a = D.id", + "Temporal table join only support join on a LookupableTableSource", + classOf[TableException]) + } + + @Test + def testJoinTemporalTable(): Unit = { + val sql = "SELECT * FROM MyTable AS T JOIN temporalTest " + + "FOR SYSTEM_TIME AS OF T.proctime AS D ON T.a = D.id" + testUtil.verifyPlan(sql) + } + + @Test + def testLeftJoinTemporalTable(): Unit = { + val sql = "SELECT * FROM MyTable AS T LEFT JOIN temporalTest " + + "FOR SYSTEM_TIME AS OF T.proctime AS D ON T.a = D.id" + testUtil.verifyPlan(sql) + } + + @Test + def testJoinTemporalTableWithNestedQuery(): Unit = { + val sql = "SELECT * FROM " + + "(SELECT a, b, proctime FROM MyTable WHERE c > 1000) AS T " + + "JOIN temporalTest " + + "FOR SYSTEM_TIME AS OF T.proctime AS D ON T.a = D.id" + testUtil.verifyPlan(sql) + } + + @Test + def testJoinTemporalTableWithProjectionPushDown(): Unit = { + val sql = + """ + |SELECT T.*, D.id + |FROM MyTable AS T + |JOIN temporalTest FOR SYSTEM_TIME AS OF T.proctime AS D + |ON T.a = D.id + """.stripMargin + testUtil.verifyPlan(sql) + } + + @Test + def testJoinTemporalTableWithFilterPushDown(): Unit = { + val sql = + """ + |SELECT * FROM MyTable AS T + |JOIN temporalTest FOR SYSTEM_TIME AS OF T.proctime AS D + |ON T.a = D.id AND D.age = 10 + |WHERE T.c > 1000 + """.stripMargin + testUtil.verifyPlan(sql) + } + + @Test + def testAvoidAggregatePushDown(): Unit = { + val sql1 = + """ + |SELECT b, a, sum(c) c, sum(d) d, PROCTIME() as proctime + |FROM T1 + |GROUP BY a, b + """.stripMargin + + val sql2 = + s""" + |SELECT T.* FROM ($sql1) AS T + |JOIN temporalTest FOR SYSTEM_TIME AS OF T.proctime AS D + |ON T.a = D.id + |WHERE D.age > 10 + """.stripMargin + + val sql = + s""" + |SELECT b, count(a), sum(c), sum(d) + |FROM ($sql2) AS T + |GROUP BY b + """.stripMargin + testUtil.verifyPlan(sql) + } + + @Test + def testReusing(): Unit = { + testUtil.tableEnv.getConfig.getConf.setBoolean( + PlannerConfigOptions.SQL_OPTIMIZER_REUSE_SUB_PLAN_ENABLED, true) + val sql1 = + """ + |SELECT b, a, sum(c) c, sum(d) d, PROCTIME() as proctime + |FROM T1 + |GROUP BY a, b + """.stripMargin + + val sql2 = + s""" + |SELECT * FROM ($sql1) AS T + |JOIN temporalTest FOR SYSTEM_TIME AS OF T.proctime AS D + |ON T.a = D.id + |WHERE D.age > 10 + """.stripMargin + val sql3 = + s""" + |SELECT id as a, b FROM ($sql2) AS T + """.stripMargin + val sql = + s""" + |SELECT count(T1.a), count(T1.id), sum(T2.a) + |FROM ($sql2) AS T1, ($sql3) AS T2 + |WHERE T1.a = T2.a + |GROUP BY T1.b, T2.b + """.stripMargin + + testUtil.verifyPlan(sql) + } + + // ========================================================================================== + + // ========================================================================================== + + private def expectExceptionThrown( + sql: String, + keywords: String, + clazz: Class[_ <: Throwable] = classOf[ValidationException]) + : Unit = { + try { + testUtil.verifyExplain(sql) + fail(s"Expected a $clazz, but no exception is thrown.") + } catch { + case e if 
e.getClass == clazz => + if (keywords != null) { + assertTrue( + s"The actual exception message \n${e.getMessage}\n" + + s"doesn't contain expected keyword \n$keywords\n", + e.getMessage.contains(keywords)) + } + case e: Throwable => + e.printStackTrace() + fail(s"Expected throw ${clazz.getSimpleName}, but is $e.") + } + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/join/ShuffledHashJoinTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/join/ShuffledHashJoinTest.scala index 904868e92326b..3199f9373e59b 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/join/ShuffledHashJoinTest.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/join/ShuffledHashJoinTest.scala @@ -45,13 +45,6 @@ class ShuffledHashJoinTest extends JoinTestBase { super.testInnerJoinWithoutJoinPred() } - @Test - override def testInnerJoinWithNonEquiPred(): Unit = { - thrown.expect(classOf[TableException]) - thrown.expectMessage("Cannot generate a valid execution plan for the given query") - super.testInnerJoinWithNonEquiPred() - } - @Test override def testLeftOuterJoinNoEquiPred(): Unit = { thrown.expect(classOf[TableException]) diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/join/SingleRowJoinTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/join/SingleRowJoinTest.scala index 705bc9680c2e0..cf1025c35663c 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/join/SingleRowJoinTest.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/join/SingleRowJoinTest.scala @@ -32,13 +32,6 @@ class SingleRowJoinTest extends TableTestBase { util.verifyPlan("SELECT a1, a_sum FROM A, (SELECT SUM(a1) + SUM(a2) AS a_sum FROM A)") } - @Test - def testSingleRowEquiJoin(): Unit = { - val util = batchTestUtil() - util.addTableSource[(Int, String)]("A", 'a1, 'a2) - util.verifyPlan("SELECT a1, a2 FROM A, (SELECT COUNT(a1) AS cnt FROM A) WHERE a1 = cnt") - } - @Test def testSingleRowNotEquiJoin(): Unit = { val util = batchTestUtil() diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/join/SortMergeJoinTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/join/SortMergeJoinTest.scala index 495b2fb6ec9ae..06401bd17df18 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/join/SortMergeJoinTest.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/join/SortMergeJoinTest.scala @@ -37,13 +37,6 @@ class SortMergeJoinTest extends JoinTestBase { super.testInnerJoinWithoutJoinPred() } - @Test - override def testInnerJoinWithNonEquiPred(): Unit = { - thrown.expect(classOf[TableException]) - thrown.expectMessage("Cannot generate a valid execution plan for the given query") - super.testInnerJoinWithNonEquiPred() - } - @Test override def testLeftOuterJoinNoEquiPred(): Unit = { thrown.expect(classOf[TableException]) diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/stream/sql/join/LookupJoinTest.scala 
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/stream/sql/join/LookupJoinTest.scala new file mode 100644 index 0000000000000..0b210f53cccc5 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/stream/sql/join/LookupJoinTest.scala @@ -0,0 +1,498 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.plan.stream.sql.join + +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.streaming.api.datastream.DataStream +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment +import org.apache.flink.streaming.api.functions.async.ResultFuture +import org.apache.flink.table.api._ +import org.apache.flink.table.dataformat.{BaseRow, BinaryString} +import org.apache.flink.table.functions.{AsyncTableFunction, TableFunction} +import org.apache.flink.table.sources._ +import org.apache.flink.api.scala._ +import org.apache.flink.table.`type`.{InternalType, InternalTypes} +import org.apache.flink.table.typeutils.BaseRowTypeInfo +import org.apache.flink.table.util.{StreamTableTestUtil, TableTestBase} +import org.apache.flink.types.Row +import org.junit.Assert.{assertTrue, fail} +import org.junit.Test + +import _root_.java.util +import _root_.java.lang.{Long => JLong} +import _root_.java.sql.Timestamp + +import _root_.scala.annotation.varargs + +class LookupJoinTest extends TableTestBase with Serializable { + private val streamUtil: StreamTableTestUtil = streamTestUtil() + streamUtil.addDataStream[(Int, String, Long)]("MyTable", 'a, 'b, 'c, 'proctime, 'rowtime) + streamUtil.addDataStream[(Int, String, Long, Double)]("T1", 'a, 'b, 'c, 'd) + streamUtil.addDataStream[(Int, String, Int)]("nonTemporal", 'id, 'name, 'age) + streamUtil.tableEnv.registerTableSource("temporalTest", new TestTemporalTable) + + @Test + def testJoinInvalidJoinTemporalTable(): Unit = { + // must follow a period specification + expectExceptionThrown( + "SELECT * FROM MyTable AS T JOIN temporalTest T.proctime AS D ON T.a = D.id", + "SQL parse failed", + classOf[SqlParserException]) + + // can't as of non-proctime field + expectExceptionThrown( + "SELECT * FROM MyTable AS T JOIN temporalTest " + + "FOR SYSTEM_TIME AS OF T.rowtime AS D ON T.a = D.id", + "Temporal table join currently only supports 'FOR SYSTEM_TIME AS OF' " + + "left table's proctime field", + classOf[TableException]) + + // can't query a dim table directly + expectExceptionThrown( + "SELECT * FROM temporalTest FOR SYSTEM_TIME AS OF TIMESTAMP '2017-08-09 14:36:11'", + "Cannot generate a valid execution plan for the given query", + classOf[TableException] + ) + + // can't on non-key fields + expectExceptionThrown( + "SELECT * FROM MyTable AS T JOIN temporalTest " + 
+ "FOR SYSTEM_TIME AS OF T.proctime AS D ON T.a = D.age", + "Temporal table join requires an equality condition on ALL fields of table " + + "[TestTemporalTable(id, name, age)]'s PRIMARY KEY or (UNIQUE) INDEX(s).", + classOf[TableException] + ) + + // only support left or inner join + expectExceptionThrown( + "SELECT * FROM MyTable AS T RIGHT JOIN temporalTest " + + "FOR SYSTEM_TIME AS OF T.proctime AS D ON T.a = D.id", + "Unsupported join type for semi-join RIGHT", + classOf[IllegalArgumentException] + ) + + // only support join on raw key of right table + expectExceptionThrown( + "SELECT * FROM MyTable AS T LEFT JOIN temporalTest " + + "FOR SYSTEM_TIME AS OF T.proctime AS D ON T.a + 1 = D.id + 2", + "Temporal table join requires an equality condition on ALL fields of table " + + "[TestTemporalTable(id, name, age)]'s PRIMARY KEY or (UNIQUE) INDEX(s).", + classOf[TableException] + ) + + // only support "FOR SYSTEM_TIME AS OF" left table's proctime + expectExceptionThrown( + "SELECT * FROM MyTable AS T LEFT JOIN temporalTest " + + "FOR SYSTEM_TIME AS OF PROCTIME() AS D ON T.a = D.id", + "Temporal table join currently only supports 'FOR SYSTEM_TIME AS OF' " + + "left table's proctime field, doesn't support 'PROCTIME()'", + classOf[TableException] + ) + } + + @Test + def testInvalidLookupTableFunction(): Unit = { + streamUtil.addDataStream[(Int, String, Long, Timestamp)]("T", 'a, 'b, 'c, 'ts, 'proctime) + + val temporalTable = new TestInvalidTemporalTable(new InvalidTableFunctionResultType) + streamUtil.tableEnv.registerTableSource("temporalTable", temporalTable) + expectExceptionThrown( + "SELECT * FROM T AS T JOIN temporalTable " + + "FOR SYSTEM_TIME AS OF T.proctime AS D ON T.a = D.id AND T.b = D.name AND T.ts = D.ts", + "The TableSource [TestInvalidTemporalTable(id, name, age, ts)] " + + "return type BaseRow(id: Integer, name: String, age: Integer, ts: Timestamp) " + + "does not match its lookup function extracted return type String", + classOf[TableException] + ) + + val temporalTable2 = new TestInvalidTemporalTable(new InvalidTableFunctionEvalSignature1) + streamUtil.tableEnv.registerTableSource("temporalTable2", temporalTable2) + expectExceptionThrown( + "SELECT * FROM T AS T JOIN temporalTable2 " + + "FOR SYSTEM_TIME AS OF T.proctime AS D ON T.a = D.id AND T.b = D.name AND T.ts = D.ts", + "Expected: eval(java.lang.Integer, org.apache.flink.table.dataformat.BinaryString, " + + "java.lang.Long) \n" + + "Actual: eval(java.lang.Integer, java.lang.String, java.sql.Timestamp)", + classOf[TableException] + ) + + val temporalTable3 = new TestInvalidTemporalTable(new ValidTableFunction) + streamUtil.tableEnv.registerTableSource("temporalTable3", temporalTable3) + verifyTranslationSuccess("SELECT * FROM T AS T JOIN temporalTable3 " + + "FOR SYSTEM_TIME AS OF T.proctime AS D " + + "ON T.a = D.id AND T.b = D.name AND T.ts = D.ts") + + val temporalTable4 = new TestInvalidTemporalTable(new ValidTableFunction2) + streamUtil.tableEnv.registerTableSource("temporalTable4", temporalTable4) + verifyTranslationSuccess("SELECT * FROM T AS T JOIN temporalTable4 " + + "FOR SYSTEM_TIME AS OF T.proctime AS D " + + "ON T.a = D.id AND T.b = D.name AND T.ts = D.ts") + + val temporalTable5 = new TestInvalidTemporalTable(new ValidAsyncTableFunction) + streamUtil.tableEnv.registerTableSource("temporalTable5", temporalTable5) + verifyTranslationSuccess("SELECT * FROM T AS T JOIN temporalTable5 " + + "FOR SYSTEM_TIME AS OF T.proctime AS D " + + "ON T.a = D.id AND T.b = D.name AND T.ts = D.ts") + + val temporalTable6 = 
new TestInvalidTemporalTable(new InvalidAsyncTableFunctionResultType) + streamUtil.tableEnv.registerTableSource("temporalTable6", temporalTable6) + verifyTranslationSuccess("SELECT * FROM T AS T JOIN temporalTable6 " + + "FOR SYSTEM_TIME AS OF T.proctime AS D ON T.a = D.id AND T.b = D.name AND T.ts = D.ts") + + val temporalTable7 = new TestInvalidTemporalTable(new InvalidAsyncTableFunctionEvalSignature1) + streamUtil.tableEnv.registerTableSource("temporalTable7", temporalTable7) + expectExceptionThrown( + "SELECT * FROM T AS T JOIN temporalTable7 " + + "FOR SYSTEM_TIME AS OF T.proctime AS D ON T.a = D.id AND T.b = D.name AND T.ts = D.ts", + "Expected: eval(org.apache.flink.streaming.api.functions.async.ResultFuture, " + + "java.lang.Integer, org.apache.flink.table.dataformat.BinaryString, java.lang.Long) \n" + + "Actual: eval(java.lang.Integer, org.apache.flink.table.dataformat.BinaryString, " + + "java.sql.Timestamp)", + classOf[TableException] + ) + + val temporalTable8 = new TestInvalidTemporalTable(new InvalidAsyncTableFunctionEvalSignature2) + streamUtil.tableEnv.registerTableSource("temporalTable8", temporalTable8) + expectExceptionThrown( + "SELECT * FROM T AS T JOIN temporalTable8 " + + "FOR SYSTEM_TIME AS OF T.proctime AS D ON T.a = D.id AND T.b = D.name AND T.ts = D.ts", + "Expected: eval(org.apache.flink.streaming.api.functions.async.ResultFuture, " + + "java.lang.Integer, org.apache.flink.table.dataformat.BinaryString, java.lang.Long) \n" + + "Actual: eval(org.apache.flink.streaming.api.functions.async.ResultFuture, " + + "java.lang.Integer, java.lang.String, java.sql.Timestamp)", + classOf[TableException] + ) + + val temporalTable9 = new TestInvalidTemporalTable(new ValidAsyncTableFunction) + streamUtil.tableEnv.registerTableSource("temporalTable9", temporalTable9) + verifyTranslationSuccess("SELECT * FROM T AS T JOIN temporalTable9 " + + "FOR SYSTEM_TIME AS OF T.proctime AS D " + + "ON T.a = D.id AND T.b = D.name AND T.ts = D.ts") + } + + @Test + def testJoinOnDifferentKeyTypes(): Unit = { + // Will do implicit type coercion. 
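+    // Note: MyTable.b is a STRING field while temporalTest.id is INT, so the planner is
+    // expected to insert an implicit cast on the join key rather than reject the plan.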
+ streamUtil.verifyPlan("SELECT * FROM MyTable AS T JOIN temporalTest " + + "FOR SYSTEM_TIME AS OF T.proctime AS D ON T.b = D.id") + } + + @Test + def testJoinInvalidNonTemporalTable(): Unit = { + // can't follow a period specification + expectExceptionThrown( + "SELECT * FROM MyTable AS T JOIN nonTemporal " + + "FOR SYSTEM_TIME AS OF T.rowtime AS D ON T.a = D.id", + "Table 'nonTemporal' is not a temporal table", + classOf[ValidationException]) + } + + @Test + def testJoinTemporalTable(): Unit = { + val sql = "SELECT * FROM MyTable AS T JOIN temporalTest " + + "FOR SYSTEM_TIME AS OF T.proctime AS D ON T.a = D.id" + + streamUtil.verifyPlan(sql) + } + + @Test + def testLeftJoinTemporalTable(): Unit = { + val sql = "SELECT * FROM MyTable AS T LEFT JOIN temporalTest " + + "FOR SYSTEM_TIME AS OF T.proctime AS D ON T.a = D.id" + + streamUtil.verifyPlan(sql) + } + + @Test + def testJoinTemporalTableWithNestedQuery(): Unit = { + val sql = "SELECT * FROM " + + "(SELECT a, b, proctime FROM MyTable WHERE c > 1000) AS T " + + "JOIN temporalTest " + + "FOR SYSTEM_TIME AS OF T.proctime AS D ON T.a = D.id" + + streamUtil.verifyPlan(sql) + } + + @Test + def testJoinTemporalTableWithProjectionPushDown(): Unit = { + val sql = + """ + |SELECT T.*, D.id + |FROM MyTable AS T + |JOIN temporalTest FOR SYSTEM_TIME AS OF T.proctime AS D + |ON T.a = D.id + """.stripMargin + + streamUtil.verifyPlan(sql) + } + + @Test + def testJoinTemporalTableWithFilterPushDown(): Unit = { + val sql = + """ + |SELECT * FROM MyTable AS T + |JOIN temporalTest FOR SYSTEM_TIME AS OF T.proctime AS D + |ON T.a = D.id AND D.age = 10 + |WHERE T.c > 1000 + """.stripMargin + + streamUtil.verifyPlan(sql) + } + + @Test + def testJoinTemporalTableWithCalcPushDown(): Unit = { + val sql = + """ + |SELECT * FROM MyTable AS T + |JOIN temporalTest FOR SYSTEM_TIME AS OF T.proctime AS D + |ON T.a = D.id AND D.age = 10 + |WHERE cast(D.name as bigint) > 1000 + """.stripMargin + + streamUtil.verifyPlan(sql) + } + + @Test + def testJoinTemporalTableWithMultiIndexColumn(): Unit = { + val sql = + """ + |SELECT * FROM MyTable AS T + |JOIN temporalTest FOR SYSTEM_TIME AS OF T.proctime AS D + |ON T.a = D.id AND D.age = 10 AND D.name = 'AAA' + |WHERE T.c > 1000 + """.stripMargin + + streamUtil.verifyPlan(sql) + } + + @Test + def testAvoidAggregatePushDown(): Unit = { + val sql1 = + """ + |SELECT b, a, sum(c) c, sum(d) d, PROCTIME() as proc + |FROM T1 + |GROUP BY a, b + """.stripMargin + + val sql2 = + s""" + |SELECT T.* FROM ($sql1) AS T + |JOIN temporalTest FOR SYSTEM_TIME AS OF T.proc AS D + |ON T.a = D.id + |WHERE D.age > 10 + """.stripMargin + + val sql = + s""" + |SELECT b, count(a), sum(c), sum(d) + |FROM ($sql2) AS T + |GROUP BY b + """.stripMargin + + streamUtil.verifyPlan(sql) + } + + // ========================================================================================== + + private def expectExceptionThrown( + sql: String, + keywords: String, + clazz: Class[_ <: Throwable] = classOf[ValidationException]) + : Unit = { + try { + streamUtil.tableEnv.toAppendStream[Row](streamUtil.tableEnv.sqlQuery(sql)) + fail(s"Expected a $clazz, but no exception is thrown.") + } catch { + case e if e.getClass == clazz => + if (keywords != null) { + assertTrue( + s"The actual exception message \n${e.getMessage}\n" + + s"doesn't contain expected keyword \n$keywords\n", + e.getMessage.contains(keywords)) + } + case e: Throwable => + e.printStackTrace() + fail(s"Expected throw ${clazz.getSimpleName}, but is $e.") + } + } + + private def verifyTranslationSuccess(sql: 
String): Unit = { + streamUtil.tableEnv.toAppendStream[Row](streamUtil.tableEnv.sqlQuery(sql)) + } +} + + +class TestTemporalTable + extends StreamTableSource[BaseRow] + with BatchTableSource[BaseRow] + with LookupableTableSource[BaseRow] + with DefinedIndexes { + + val fieldNames: Array[String] = Array("id", "name", "age") + val fieldTypes: Array[TypeInformation[_]] = Array(Types.INT, Types.STRING, Types.INT) + + override def getLookupFunction(lookupKeys: Array[String]): TableFunction[BaseRow] = { + throw new UnsupportedOperationException("This TableSource is only used for unit test, " + + "this method should never be called.") + } + + override def getAsyncLookupFunction(lookupKeys: Array[String]): AsyncTableFunction[BaseRow] = { + throw new UnsupportedOperationException("This TableSource is only used for unit test, " + + "this method should never be called.") + } + + override def getLookupConfig: LookupConfig = LookupConfig.DEFAULT + + override def getReturnType: TypeInformation[BaseRow] = { + new BaseRowTypeInfo( + Array(InternalTypes.INT, InternalTypes.STRING, InternalTypes.INT) + .asInstanceOf[Array[InternalType]], + fieldNames) + } + + override def getTableSchema: TableSchema = new TableSchema(fieldNames, fieldTypes) + + override def getIndexes: util.Collection[TableIndex] = { + val index1 = TableIndex.builder() + .normalIndex() + .indexedColumns("name") + .build() + val index2 = TableIndex.builder() + .uniqueIndex() + .indexedColumns("id") + .build() + util.Arrays.asList(index1, index2) + } + + override def getDataStream(execEnv: StreamExecutionEnvironment): DataStream[BaseRow] = { + throw new UnsupportedOperationException("This TableSource is only used for unit test, " + + "this method should never be called.") + } + + override def getBoundedStream(streamEnv: StreamExecutionEnvironment): DataStream[BaseRow] = { + throw new UnsupportedOperationException("This TableSource is only used for unit test, " + + "this method should never be called.") + } +} + +class TestInvalidTemporalTable private( + async: Boolean, + fetcher: TableFunction[_], + asyncFetcher: AsyncTableFunction[_]) + extends StreamTableSource[BaseRow] + with LookupableTableSource[BaseRow] + with DefinedIndexes { + + val fieldNames: Array[String] = Array("id", "name", "age", "ts") + val fieldTypes: Array[TypeInformation[_]] = Array( + Types.INT, Types.STRING, Types.INT, Types.SQL_TIMESTAMP) + + def this(fetcher: TableFunction[_]) { + this(false, fetcher, null) + } + + def this(asyncFetcher: AsyncTableFunction[_]) { + this(true, null, asyncFetcher) + } + + override def getReturnType: TypeInformation[BaseRow] = { + new BaseRowTypeInfo( + Array(InternalTypes.INT, InternalTypes.STRING, InternalTypes.INT, InternalTypes.TIMESTAMP) + .asInstanceOf[Array[InternalType]], + fieldNames) + } + + override def getTableSchema: TableSchema = new TableSchema(fieldNames, fieldTypes) + + override def getLookupFunction(lookupKeys: Array[String]): TableFunction[BaseRow] = { + fetcher.asInstanceOf[TableFunction[BaseRow]] + } + + override def getAsyncLookupFunction(lookupKeys: Array[String]): AsyncTableFunction[BaseRow] = { + asyncFetcher.asInstanceOf[AsyncTableFunction[BaseRow]] + } + + override def getLookupConfig: LookupConfig = { + LookupConfig.builder() + .setAsyncEnabled(async) + .build() + } + + override def getIndexes: util.Collection[TableIndex] = { + util.Collections.singleton(TableIndex.builder() + .uniqueIndex() + .indexedColumns("id", "name", "ts") + .build()) + } + + override def getDataStream(execEnv: StreamExecutionEnvironment): 
DataStream[BaseRow] = { + throw new UnsupportedOperationException("This TableSource is only used for unit test, " + + "this method should never be called.") + } + +} + +class InvalidTableFunctionResultType extends TableFunction[String] { + @varargs + def eval(obj: AnyRef*): Unit = { + } +} + +class InvalidTableFunctionEvalSignature1 extends TableFunction[BaseRow] { + def eval(a: Integer, b: String, c: Timestamp): Unit = { + } +} + +class ValidTableFunction extends TableFunction[BaseRow] { + @varargs + def eval(obj: AnyRef*): Unit = { + } +} + +class ValidTableFunction2 extends TableFunction[Row] { + def eval(a: Integer, b: String, c: Timestamp): Unit = { + } +} + +class InvalidAsyncTableFunctionResultType extends AsyncTableFunction[Row] { + @varargs + def eval(obj: AnyRef*): Unit = { + } +} + +class InvalidAsyncTableFunctionEvalSignature1 extends AsyncTableFunction[BaseRow] { + def eval(a: Integer, b: BinaryString, c: Timestamp): Unit = { + } +} + +class InvalidAsyncTableFunctionEvalSignature2 extends AsyncTableFunction[BaseRow] { + def eval(resultFuture: ResultFuture[BaseRow], a: Integer, b: String, c: Timestamp): Unit = { + } +} + +class ValidAsyncTableFunction extends AsyncTableFunction[BaseRow] { + @varargs + def eval(resultFuture: ResultFuture[BaseRow], objs: AnyRef*): Unit = { + } +} + +class ValidAsyncTableFunction2 extends AsyncTableFunction[BaseRow] { + def eval(resultFuture: ResultFuture[BaseRow], a: Integer, b: BinaryString, c: JLong): Unit = { + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/join/LookupJoinITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/join/LookupJoinITCase.scala new file mode 100644 index 0000000000000..fef65224c32ef --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/join/LookupJoinITCase.scala @@ -0,0 +1,234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.table.runtime.batch.sql.join + +import org.apache.flink.api.common.typeinfo.BasicTypeInfo._ +import org.apache.flink.api.java.typeutils.RowTypeInfo +import org.apache.flink.table.api.Types +import org.apache.flink.table.runtime.utils.{BatchTableEnvUtil, BatchTestBase, InMemoryLookupableTableSource} +import org.junit.{Before, Test} + +class LookupJoinITCase extends BatchTestBase { + + val data = List( + BatchTestBase.row(1L, 12L, "Julian"), + BatchTestBase.row(2L, 15L, "Hello"), + BatchTestBase.row(3L, 15L, "Fabian"), + BatchTestBase.row(8L, 11L, "Hello world"), + BatchTestBase.row(9L, 12L, "Hello world!")) + + val typeInfo = new RowTypeInfo(LONG_TYPE_INFO, LONG_TYPE_INFO, STRING_TYPE_INFO) + + val userData = List( + (11, 1L, "Julian"), + (22, 2L, "Jark"), + (33, 3L, "Fabian")) + + val userTableSource = InMemoryLookupableTableSource.builder() + .data(userData) + .field("age", Types.INT) + .field("id", Types.LONG) + .field("name", Types.STRING) + .primaryKey("id") + .build() + + val userAsyncTableSource = InMemoryLookupableTableSource.builder() + .data(userData) + .field("age", Types.INT) + .field("id", Types.LONG) + .field("name", Types.STRING) + .primaryKey("id") + .enableAsync() + .build() + + @Before + def setup() { + BatchTableEnvUtil.registerCollection(tEnv, "T0", data, typeInfo, "id, len, content") + val myTable = tEnv.sqlQuery("SELECT *, PROCTIME() as proctime FROM T0") + tEnv.registerTable("T", myTable) + tEnv.registerTableSource("userTable", userTableSource) + tEnv.registerTableSource("userAsyncTable", userAsyncTableSource) + } + + @Test + def testJoinTemporalTable(): Unit = { + val sql = "SELECT T.id, T.len, T.content, D.name FROM T JOIN userTable " + + "for system_time as of T.proctime AS D ON T.id = D.id" + + val expected = Seq( + BatchTestBase.row(1, 12, "Julian", "Julian"), + BatchTestBase.row(2, 15, "Hello", "Jark"), + BatchTestBase.row(3, 15, "Fabian", "Fabian")) + checkResult(sql, expected, false) + } + + @Test + def testJoinTemporalTableWithPushDown(): Unit = { + val sql = "SELECT T.id, T.len, T.content, D.name FROM T JOIN userTable " + + "for system_time as of T.proctime AS D ON T.id = D.id AND D.age > 20" + + val expected = Seq( + BatchTestBase.row(2, 15, "Hello", "Jark"), + BatchTestBase.row(3, 15, "Fabian", "Fabian")) + checkResult(sql, expected, false) + } + + @Test + def testJoinTemporalTableWithNonEqualFilter(): Unit = { + val sql = "SELECT T.id, T.len, T.content, D.name, D.age FROM T JOIN userTable " + + "for system_time as of T.proctime AS D ON T.id = D.id WHERE T.len <= D.age" + + val expected = Seq( + BatchTestBase.row(2, 15, "Hello", "Jark", 22), + BatchTestBase.row(3, 15, "Fabian", "Fabian", 33)) + checkResult(sql, expected, false) + } + + @Test + def testJoinTemporalTableOnMultiFields(): Unit = { + val sql = "SELECT T.id, T.len, D.name FROM T JOIN userTable " + + "for system_time as of T.proctime AS D ON T.id = D.id AND T.content = D.name" + + val expected = Seq( + BatchTestBase.row(1, 12, "Julian"), + BatchTestBase.row(3, 15, "Fabian")) + checkResult(sql, expected, false) + } + + @Test + def testJoinTemporalTableOnMultiFieldsWithUdf(): Unit = { + val sql = "SELECT T.id, T.len, D.name FROM T JOIN userTable " + + "for system_time as of T.proctime AS D ON mod(T.id, 4) = D.id AND T.content = D.name" + + val expected = Seq( + BatchTestBase.row(1, 12, "Julian"), + BatchTestBase.row(3, 15, "Fabian")) + checkResult(sql, expected, false) + } + + @Test + def testJoinTemporalTableOnMultiKeyFields(): Unit = { + val sql = "SELECT T.id, 
T.len, D.name FROM T JOIN userTable " + + "for system_time as of T.proctime AS D ON T.content = D.name AND T.id = D.id" + + val expected = Seq( + BatchTestBase.row(1, 12, "Julian"), + BatchTestBase.row(3, 15, "Fabian")) + checkResult(sql, expected, false) + } + + @Test + def testLeftJoinTemporalTable(): Unit = { + val sql = "SELECT T.id, T.len, D.name, D.age FROM T LEFT JOIN userTable " + + "for system_time as of T.proctime AS D ON T.id = D.id" + + val expected = Seq( + BatchTestBase.row(1, 12, "Julian", 11), + BatchTestBase.row(2, 15, "Jark", 22), + BatchTestBase.row(3, 15, "Fabian", 33), + BatchTestBase.row(8, 11, null, null), + BatchTestBase.row(9, 12, null, null)) + checkResult(sql, expected, false) + } + + @Test + def testAsyncJoinTemporalTable(): Unit = { + // TODO: enable object reuse until [FLINK-12351] is fixed. + env.getConfig.disableObjectReuse() + val sql = "SELECT T.id, T.len, T.content, D.name FROM T JOIN userAsyncTable " + + "for system_time as of T.proctime AS D ON T.id = D.id" + + val expected = Seq( + BatchTestBase.row(1, 12, "Julian", "Julian"), + BatchTestBase.row(2, 15, "Hello", "Jark"), + BatchTestBase.row(3, 15, "Fabian", "Fabian")) + checkResult(sql, expected, false) + } + + @Test + def testAsyncJoinTemporalTableWithPushDown(): Unit = { + // TODO: enable object reuse until [FLINK-12351] is fixed. + env.getConfig.disableObjectReuse() + val sql = "SELECT T.id, T.len, T.content, D.name FROM T JOIN userAsyncTable " + + "for system_time as of T.proctime AS D ON T.id = D.id AND D.age > 20" + + val expected = Seq( + BatchTestBase.row(2, 15, "Hello", "Jark"), + BatchTestBase.row(3, 15, "Fabian", "Fabian")) + checkResult(sql, expected, false) + } + + @Test + def testAsyncJoinTemporalTableWithNonEqualFilter(): Unit = { + // TODO: enable object reuse until [FLINK-12351] is fixed. + env.getConfig.disableObjectReuse() + val sql = "SELECT T.id, T.len, T.content, D.name, D.age FROM T JOIN userAsyncTable " + + "for system_time as of T.proctime AS D ON T.id = D.id WHERE T.len <= D.age" + + val expected = Seq( + BatchTestBase.row(2, 15, "Hello", "Jark", 22), + BatchTestBase.row(3, 15, "Fabian", "Fabian", 33)) + checkResult(sql, expected, false) + } + + @Test + def testAsyncLeftJoinTemporalTableWithLocalPredicate(): Unit = { + // TODO: enable object reuse until [FLINK-12351] is fixed. + env.getConfig.disableObjectReuse() + val sql = "SELECT T.id, T.len, T.content, D.name, D.age FROM T LEFT JOIN userAsyncTable " + + "for system_time as of T.proctime AS D ON T.id = D.id " + + "AND T.len > 1 AND D.age > 20 AND D.name = 'Fabian' " + + "WHERE T.id > 1" + + val expected = Seq( + BatchTestBase.row(2, 15, "Hello", null, null), + BatchTestBase.row(3, 15, "Fabian", "Fabian", 33), + BatchTestBase.row(8, 11, "Hello world", null, null), + BatchTestBase.row(9, 12, "Hello world!", null, null)) + checkResult(sql, expected, false) + } + + @Test + def testAsyncJoinTemporalTableOnMultiFields(): Unit = { + // TODO: enable object reuse until [FLINK-12351] is fixed. + env.getConfig.disableObjectReuse() + val sql = "SELECT T.id, T.len, D.name FROM T JOIN userAsyncTable " + + "for system_time as of T.proctime AS D ON T.id = D.id AND T.content = D.name" + + val expected = Seq( + BatchTestBase.row(1, 12, "Julian"), + BatchTestBase.row(3, 15, "Fabian")) + checkResult(sql, expected, false) + } + + @Test + def testAsyncLeftJoinTemporalTable(): Unit = { + // TODO: enable object reuse until [FLINK-12351] is fixed. 
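+    // Object reuse is kept disabled because AsyncWaitOperator currently does not copy
+    // its input elements (see FLINK-12351 referenced above).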
+ env.getConfig.disableObjectReuse() + val sql = "SELECT T.id, T.len, D.name, D.age FROM T LEFT JOIN userAsyncTable " + + "for system_time as of T.proctime AS D ON T.id = D.id" + + val expected = Seq( + BatchTestBase.row(1, 12, "Julian", 11), + BatchTestBase.row(2, 15, "Jark", 22), + BatchTestBase.row(3, 15, "Fabian", 33), + BatchTestBase.row(8, 11, null, null), + BatchTestBase.row(9, 12, null, null)) + checkResult(sql, expected, false) + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/AsyncLookupJoinITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/AsyncLookupJoinITCase.scala new file mode 100644 index 0000000000000..8575421c79143 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/AsyncLookupJoinITCase.scala @@ -0,0 +1,353 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.runtime.stream.sql + +import org.apache.flink.api.common.restartstrategy.RestartStrategies +import org.apache.flink.api.scala._ +import org.apache.flink.table.api.scala._ +import org.apache.flink.table.api.Types +import org.apache.flink.table.runtime.utils.StreamingWithStateTestBase.StateBackendMode +import org.apache.flink.table.runtime.utils.UserDefinedFunctionTestUtils._ +import org.apache.flink.table.runtime.utils.{InMemoryLookupableTableSource, StreamingWithStateTestBase, TestingAppendSink, TestingRetractSink} +import org.apache.flink.types.Row +import org.apache.flink.util.ExceptionUtils +import org.junit.Assert.{assertEquals, assertTrue, fail} +import org.junit.{Before, Test} +import org.junit.runner.RunWith +import org.junit.runners.Parameterized + +@RunWith(classOf[Parameterized]) +class AsyncLookupJoinITCase(backend: StateBackendMode) + extends StreamingWithStateTestBase(backend) { + + val data = List( + (1L, 12, "Julian"), + (2L, 15, "Hello"), + (3L, 15, "Fabian"), + (8L, 11, "Hello world"), + (9L, 12, "Hello world!")) + + val userData = List( + (11, 1L, "Julian"), + (22, 2L, "Jark"), + (33, 3L, "Fabian")) + + val userTableSource = InMemoryLookupableTableSource.builder() + .data(userData) + .field("age", Types.INT) + .field("id", Types.LONG) + .field("name", Types.STRING) + .primaryKey("id") + .enableAsync() + .build() + + val userTableSourceWith2Keys = InMemoryLookupableTableSource.builder() + .data(userData) + .field("age", Types.INT) + .field("id", Types.LONG) + .field("name", Types.STRING) + .addUniqueIndex("id", "name") + .enableAsync() + .build() + + + // TODO: remove this until [FLINK-12351] is fixed. 
+ // currently AsyncWaitOperator doesn't copy input element which is a bug + @Before + override def before(): Unit = { + super.before() + env.getConfig.disableObjectReuse() + } + + @Test + def testAsyncJoinTemporalTableOnMultiKeyFields(): Unit = { + val streamTable = failingDataSource(data) + .toTable(tEnv, 'id, 'len, 'content, 'proctime) + tEnv.registerTable("T", streamTable) + + // pk is (id: Long, name: String) + tEnv.registerTableSource("userTable", userTableSourceWith2Keys) + + // test left table's join key define order diffs from right's + val sql = + """ + |SELECT t1.id, t1.len, D.name + |FROM (select content, id, len, proctime FROM T) t1 + |JOIN userTable for system_time as of t1.proctime AS D + |ON t1.content = D.name AND t1.id = D.id + """.stripMargin + + val sink = new TestingAppendSink + tEnv.sqlQuery(sql).toAppendStream[Row].addSink(sink) + env.execute() + + val expected = Seq( + "1,12,Julian", + "3,15,Fabian") + assertEquals(expected.sorted, sink.getAppendResults.sorted) + assertEquals(0, userTableSourceWith2Keys.getResourceCounter) + } + + @Test + def testAsyncJoinTemporalTable(): Unit = { + val streamTable = failingDataSource(data) + .toTable(tEnv, 'id, 'len, 'content, 'proctime) + tEnv.registerTable("T", streamTable) + + tEnv.registerTableSource("userTable", userTableSource) + + val sql = "SELECT T.id, T.len, T.content, D.name FROM T JOIN userTable " + + "for system_time as of T.proctime AS D ON T.id = D.id" + + val sink = new TestingAppendSink + tEnv.sqlQuery(sql).toAppendStream[Row].addSink(sink) + env.execute() + + val expected = Seq( + "1,12,Julian,Julian", + "2,15,Hello,Jark", + "3,15,Fabian,Fabian") + assertEquals(expected.sorted, sink.getAppendResults.sorted) + assertEquals(0, userTableSource.getResourceCounter) + } + + @Test + def testAsyncJoinTemporalTableWithPushDown(): Unit = { + val streamTable = env.fromCollection(data) + .toTable(tEnv, 'id, 'len, 'content, 'proctime) + tEnv.registerTable("T", streamTable) + + tEnv.registerTableSource("userTable", userTableSource) + + val sql = "SELECT T.id, T.len, T.content, D.name FROM T JOIN userTable " + + "for system_time as of T.proctime AS D ON T.id = D.id AND D.age > 20" + + val sink = new TestingAppendSink + tEnv.sqlQuery(sql).toAppendStream[Row].addSink(sink) + env.execute() + + val expected = Seq( + "2,15,Hello,Jark", + "3,15,Fabian,Fabian") + assertEquals(expected.sorted, sink.getAppendResults.sorted) + assertEquals(0, userTableSource.getResourceCounter) + } + + @Test + def testAsyncJoinTemporalTableWithNonEqualFilter(): Unit = { + val streamTable = failingDataSource(data) + .toTable(tEnv, 'id, 'len, 'content, 'proctime) + tEnv.registerTable("T", streamTable) + + tEnv.registerTableSource("userTable", userTableSource) + + val sql = "SELECT T.id, T.len, T.content, D.name, D.age FROM T JOIN userTable " + + "for system_time as of T.proctime AS D ON T.id = D.id WHERE T.len <= D.age" + + val sink = new TestingAppendSink + tEnv.sqlQuery(sql).toAppendStream[Row].addSink(sink) + env.execute() + + val expected = Seq( + "2,15,Hello,Jark,22", + "3,15,Fabian,Fabian,33") + assertEquals(expected.sorted, sink.getAppendResults.sorted) + assertEquals(0, userTableSource.getResourceCounter) + } + + @Test + def testAsyncLeftJoinTemporalTableWithLocalPredicate(): Unit = { + val streamTable = failingDataSource(data) + .toTable(tEnv, 'id, 'len, 'content, 'proctime) + tEnv.registerTable("T", streamTable) + + tEnv.registerTableSource("userTable", userTableSource) + + val sql = "SELECT T.id, T.len, T.content, D.name, D.age FROM T LEFT JOIN 
userTable " + + "for system_time as of T.proctime AS D ON T.id = D.id " + + "AND T.len > 1 AND D.age > 20 AND D.name = 'Fabian' " + + "WHERE T.id > 1" + + val sink = new TestingAppendSink + tEnv.sqlQuery(sql).toAppendStream[Row].addSink(sink) + env.execute() + + val expected = Seq( + "2,15,Hello,null,null", + "3,15,Fabian,Fabian,33", + "8,11,Hello world,null,null", + "9,12,Hello world!,null,null") + assertEquals(expected.sorted, sink.getAppendResults.sorted) + assertEquals(0, userTableSource.getResourceCounter) + } + + @Test + def testAsyncJoinTemporalTableOnMultiFields(): Unit = { + val streamTable = failingDataSource(data) + .toTable(tEnv, 'id, 'len, 'content, 'proctime) + tEnv.registerTable("T", streamTable) + + tEnv.registerTableSource("userTable", userTableSource) + + val sql = "SELECT T.id, T.len, D.name FROM T JOIN userTable " + + "for system_time as of T.proctime AS D ON T.id = D.id AND T.content = D.name" + + val sink = new TestingAppendSink + tEnv.sqlQuery(sql).toAppendStream[Row].addSink(sink) + env.execute() + + val expected = Seq( + "1,12,Julian", + "3,15,Fabian") + assertEquals(expected.sorted, sink.getAppendResults.sorted) + assertEquals(0, userTableSource.getResourceCounter) + } + + @Test + def testAsyncJoinTemporalTableOnMultiFieldsWithUdf(): Unit = { + val streamTable = failingDataSource(data) + .toTable(tEnv, 'id, 'len, 'content, 'proctime) + tEnv.registerTable("T", streamTable) + + tEnv.registerTableSource("userTable", userTableSource) + tEnv.registerFunction("mod1", TestMod) + tEnv.registerFunction("wrapper1", TestWrapperUdf) + + val sql = "SELECT T.id, T.len, wrapper1(D.name) as name FROM T JOIN userTable " + + "for system_time as of T.proctime AS D " + + "ON mod1(T.id, 4) = D.id AND T.content = D.name" + + val sink = new TestingAppendSink + tEnv.sqlQuery(sql).toAppendStream[Row].addSink(sink) + env.execute() + + val expected = Seq( + "1,12,Julian", + "3,15,Fabian") + assertEquals(expected.sorted, sink.getAppendResults.sorted) + assertEquals(0, userTableSource.getResourceCounter) + } + + @Test + def testAsyncJoinTemporalTableWithUdfFilter(): Unit = { + val streamTable = failingDataSource(data) + .toTable(tEnv, 'id, 'len, 'content, 'proctime) + tEnv.registerTable("T", streamTable) + + tEnv.registerTableSource("userTable", userTableSource) + tEnv.registerFunction("add", new TestAddWithOpen) + + val sql = "SELECT T.id, T.len, T.content, D.name FROM T JOIN userTable " + + "for system_time as of T.proctime AS D ON T.id = D.id " + + "WHERE add(T.id, D.id) > 3 AND add(T.id, 2) > 3 AND add (D.id, 2) > 3" + + val sink = new TestingAppendSink + tEnv.sqlQuery(sql).toAppendStream[Row].addSink(sink) + env.execute() + + val expected = Seq( + "2,15,Hello,Jark", + "3,15,Fabian,Fabian") + assertEquals(expected.sorted, sink.getAppendResults.sorted) + assertEquals(0, userTableSource.getResourceCounter) + assertEquals(0, TestAddWithOpen.aliveCounter.get()) + } + + @Test + def testAggAndAsyncLeftJoinTemporalTable(): Unit = { + val streamTable = failingDataSource(data) + .toTable(tEnv, 'id, 'len, 'content, 'proctime) + tEnv.registerTable("T", streamTable) + + tEnv.registerTableSource("userTable", userTableSource) + + val sql1 = "SELECT max(id) as id, PROCTIME() as proctime from T group by len" + + val table1 = tEnv.sqlQuery(sql1) + tEnv.registerTable("t1", table1) + + val sql2 = "SELECT t1.id, D.name, D.age FROM t1 LEFT JOIN userTable " + + "for system_time as of t1.proctime AS D ON t1.id = D.id" + + val sink = new TestingRetractSink + 
tEnv.sqlQuery(sql2).toRetractStream[Row].addSink(sink).setParallelism(1) + env.execute() + + val expected = Seq( + "3,Fabian,33", + "8,null,null", + "9,null,null") + assertEquals(expected.sorted, sink.getRetractResults.sorted) + } + + + @Test + def testAsyncLeftJoinTemporalTable(): Unit = { + val streamTable = failingDataSource(data) + .toTable(tEnv, 'id, 'len, 'content, 'proctime) + tEnv.registerTable("T", streamTable) + + tEnv.registerTableSource("userTable", userTableSource) + + val sql = "SELECT T.id, T.len, D.name, D.age FROM T LEFT JOIN userTable " + + "for system_time as of T.proctime AS D ON T.id = D.id" + + val sink = new TestingAppendSink + tEnv.sqlQuery(sql).toAppendStream[Row].addSink(sink) + env.execute() + + val expected = Seq( + "1,12,Julian,11", + "2,15,Jark,22", + "3,15,Fabian,33", + "8,11,null,null", + "9,12,null,null") + assertEquals(expected.sorted, sink.getAppendResults.sorted) + assertEquals(0, userTableSource.getResourceCounter) + } + + @Test + def testExceptionThrownFromAsyncJoinTemporalTable(): Unit = { + env.setRestartStrategy(RestartStrategies.noRestart()) + val streamTable = env.fromCollection(data) + .toTable(tEnv, 'id, 'len, 'content, 'proctime) + tEnv.registerTable("T", streamTable) + + tEnv.registerTableSource("userTable", userTableSource) + tEnv.registerFunction("errorFunc", TestExceptionThrown) + + val sql = "SELECT T.id, T.len, D.name, D.age FROM T LEFT JOIN userTable " + + "for system_time as of T.proctime AS D ON T.id = D.id " + + "where errorFunc(D.name) > cast(1000 as decimal(10,4))" // should exception here + + val sink = new TestingAppendSink + tEnv.sqlQuery(sql).toAppendStream[Row].addSink(sink) + + try { + env.execute() + } catch { + case t: Throwable => + val exception = ExceptionUtils.findThrowable(t, classOf[NumberFormatException]) + assertTrue(exception.isPresent) + assertTrue(exception.get().getMessage.contains("Cannot parse")) + return + } + fail("NumberFormatException is expected here!") + } + +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/LookupJoinITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/LookupJoinITCase.scala new file mode 100644 index 0000000000000..b55b3321bcc38 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/LookupJoinITCase.scala @@ -0,0 +1,422 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.table.runtime.stream.sql + +import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} +import org.apache.flink.api.java.typeutils.RowTypeInfo +import org.apache.flink.api.scala._ +import org.apache.flink.table.api.Types +import org.apache.flink.table.api.scala._ +import org.apache.flink.table.runtime.utils.UserDefinedFunctionTestUtils.TestAddWithOpen +import org.apache.flink.table.runtime.utils.{InMemoryLookupableTableSource, StreamingTestBase, TestingAppendSink} +import org.apache.flink.types.Row +import org.junit.Assert.assertEquals +import org.junit.Test + +import java.lang.{Integer => JInt, Long => JLong} + +class LookupJoinITCase extends StreamingTestBase { + + val data = List( + (1L, 12, "Julian"), + (2L, 15, "Hello"), + (3L, 15, "Fabian"), + (8L, 11, "Hello world"), + (9L, 12, "Hello world!")) + + val dataWithNull = List( + Row.of(null, new JInt(15), "Hello"), + Row.of(new JLong(3), new JInt(15), "Fabian"), + Row.of(null, new JInt(11), "Hello world"), + Row.of(new JLong(9), new JInt(12), "Hello world!")) + + val userData = List( + (11, 1L, "Julian"), + (22, 2L, "Jark"), + (33, 3L, "Fabian")) + + val userTableSource = InMemoryLookupableTableSource.builder() + .data(userData) + .field("age", Types.INT) + .field("id", Types.LONG) + .field("name", Types.STRING) + .primaryKey("id") + .build() + + val userTableSourceWith2Keys = InMemoryLookupableTableSource.builder() + .data(userData) + .field("age", Types.INT) + .field("id", Types.LONG) + .field("name", Types.STRING) + .addUniqueIndex("id", "name") + .build() + + @Test + def testJoinTemporalTable(): Unit = { + val streamTable = env.fromCollection(data) + .toTable(tEnv, 'id, 'len, 'content, 'proctime) + tEnv.registerTable("T", streamTable) + + tEnv.registerTableSource("userTable", userTableSource) + + val sql = "SELECT T.id, T.len, T.content, D.name FROM T JOIN userTable " + + "for system_time as of T.proctime AS D ON T.id = D.id" + + val sink = new TestingAppendSink + tEnv.sqlQuery(sql).toAppendStream[Row].addSink(sink) + env.execute() + + val expected = Seq( + "1,12,Julian,Julian", + "2,15,Hello,Jark", + "3,15,Fabian,Fabian") + assertEquals(expected.sorted, sink.getAppendResults.sorted) + assertEquals(0, userTableSource.getResourceCounter) + } + + @Test + def testJoinTemporalTableWithUdfFilter(): Unit = { + val streamTable = env.fromCollection(data) + .toTable(tEnv, 'id, 'len, 'content, 'proctime) + tEnv.registerTable("T", streamTable) + + tEnv.registerTableSource("userTable", userTableSource) + tEnv.registerFunction("add", new TestAddWithOpen) + + val sql = "SELECT T.id, T.len, T.content, D.name FROM T JOIN userTable " + + "for system_time as of T.proctime AS D ON T.id = D.id " + + "WHERE add(T.id, D.id) > 3 AND add(T.id, 2) > 3 AND add (D.id, 2) > 3" + + val sink = new TestingAppendSink + tEnv.sqlQuery(sql).toAppendStream[Row].addSink(sink) + env.execute() + + val expected = Seq( + "2,15,Hello,Jark", + "3,15,Fabian,Fabian") + assertEquals(expected.sorted, sink.getAppendResults.sorted) + assertEquals(0, userTableSource.getResourceCounter) + assertEquals(0, TestAddWithOpen.aliveCounter.get()) + } + + @Test + def testJoinTemporalTableOnConstantKey(): Unit = { + val streamTable = env.fromCollection(data) + .toTable(tEnv, 'id, 'len, 'content, 'proctime) + tEnv.registerTable("T", streamTable) + + tEnv.registerTableSource("userTable", userTableSource) + + val sql = "SELECT T.id, T.len, T.content, D.name FROM T JOIN userTable " + + "for system_time as of T.proctime AS D ON D.id = 1" + + val 
sink = new TestingAppendSink + tEnv.sqlQuery(sql).toAppendStream[Row].addSink(sink) + env.execute() + + val expected = Seq( + "1,12,Julian,Julian", "2,15,Hello,Julian", "3,15,Fabian,Julian", + "8,11,Hello world,Julian", "9,12,Hello world!,Julian") + assertEquals(expected.sorted, sink.getAppendResults.sorted) + assertEquals(0, userTableSource.getResourceCounter) + } + + @Test + def testJoinTemporalTableOnNullableKey(): Unit = { + + implicit val tpe: TypeInformation[Row] = new RowTypeInfo( + BasicTypeInfo.LONG_TYPE_INFO, + BasicTypeInfo.INT_TYPE_INFO, + BasicTypeInfo.STRING_TYPE_INFO) + val streamTable = env.fromCollection(dataWithNull) + .toTable(tEnv, 'id, 'len, 'content, 'proctime) + tEnv.registerTable("T", streamTable) + + tEnv.registerTableSource("userTable", userTableSource) + + val sql = "SELECT T.id, T.len, D.name FROM T JOIN userTable " + + "for system_time as of T.proctime AS D ON T.id = D.id" + + val sink = new TestingAppendSink + tEnv.sqlQuery(sql).toAppendStream[Row].addSink(sink) + env.execute() + + val expected = Seq("3,15,Fabian") + assertEquals(expected.sorted, sink.getAppendResults.sorted) + assertEquals(0, userTableSource.getResourceCounter) + } + + @Test + def testJoinTemporalTableWithPushDown(): Unit = { + val streamTable = env.fromCollection(data) + .toTable(tEnv, 'id, 'len, 'content, 'proctime) + tEnv.registerTable("T", streamTable) + + tEnv.registerTableSource("userTable", userTableSource) + + val sql = "SELECT T.id, T.len, T.content, D.name FROM T JOIN userTable " + + "for system_time as of T.proctime AS D ON T.id = D.id AND D.age > 20" + + val sink = new TestingAppendSink + tEnv.sqlQuery(sql).toAppendStream[Row].addSink(sink) + env.execute() + + val expected = Seq( + "2,15,Hello,Jark", + "3,15,Fabian,Fabian") + assertEquals(expected.sorted, sink.getAppendResults.sorted) + assertEquals(0, userTableSource.getResourceCounter) + } + + @Test + def testJoinTemporalTableWithNonEqualFilter(): Unit = { + val streamTable = env.fromCollection(data) + .toTable(tEnv, 'id, 'len, 'content, 'proctime) + tEnv.registerTable("T", streamTable) + + tEnv.registerTableSource("userTable", userTableSource) + + val sql = "SELECT T.id, T.len, T.content, D.name, D.age FROM T JOIN userTable " + + "for system_time as of T.proctime AS D ON T.id = D.id WHERE T.len <= D.age" + + val sink = new TestingAppendSink + tEnv.sqlQuery(sql).toAppendStream[Row].addSink(sink) + env.execute() + + val expected = Seq( + "2,15,Hello,Jark,22", + "3,15,Fabian,Fabian,33") + assertEquals(expected.sorted, sink.getAppendResults.sorted) + assertEquals(0, userTableSource.getResourceCounter) + } + + @Test + def testJoinTemporalTableOnMultiFields(): Unit = { + val streamTable = env.fromCollection(data) + .toTable(tEnv, 'id, 'len, 'content, 'proctime) + tEnv.registerTable("T", streamTable) + + tEnv.registerTableSource("userTable", userTableSource) + + val sql = "SELECT T.id, T.len, D.name FROM T JOIN userTable " + + "for system_time as of T.proctime AS D ON T.id = D.id AND T.content = D.name" + + val sink = new TestingAppendSink + tEnv.sqlQuery(sql).toAppendStream[Row].addSink(sink) + env.execute() + + val expected = Seq( + "1,12,Julian", + "3,15,Fabian") + assertEquals(expected.sorted, sink.getAppendResults.sorted) + assertEquals(0, userTableSource.getResourceCounter) + } + + @Test + def testJoinTemporalTableOnMultiKeyFields(): Unit = { + val streamTable = env.fromCollection(data) + .toTable(tEnv, 'id, 'len, 'content, 'proctime) + tEnv.registerTable("T", streamTable) + + tEnv.registerTableSource("userTable", 
userTableSourceWith2Keys) + + val sql = "SELECT T.id, T.len, D.name FROM T JOIN userTable " + + "for system_time as of T.proctime AS D ON T.content = D.name AND T.id = D.id" + + val sink = new TestingAppendSink + tEnv.sqlQuery(sql).toAppendStream[Row].addSink(sink) + env.execute() + + val expected = Seq( + "1,12,Julian", + "3,15,Fabian") + assertEquals(expected.sorted, sink.getAppendResults.sorted) + assertEquals(0, userTableSourceWith2Keys.getResourceCounter) + } + + @Test + def testJoinTemporalTableOnMultiKeyFields2(): Unit = { + val streamTable = env.fromCollection(data) + .toTable(tEnv, 'id, 'len, 'content, 'proctime) + tEnv.registerTable("T", streamTable) + + // pk is (id: Int, name: String) + tEnv.registerTableSource("userTable", userTableSourceWith2Keys) + + // test left table's join key define order diffs from right's + val sql = "SELECT t1.id, t1.len, D.name FROM (select proctime, content, id, len FROM T) t1 " + + "JOIN userTable for system_time as of t1.proctime AS D " + + "ON t1.content = D.name AND t1.id = D.id" + + val sink = new TestingAppendSink + tEnv.sqlQuery(sql).toAppendStream[Row].addSink(sink) + env.execute() + + val expected = Seq( + "1,12,Julian", + "3,15,Fabian") + assertEquals(expected.sorted, sink.getAppendResults.sorted) + assertEquals(0, userTableSourceWith2Keys.getResourceCounter) + } + + @Test + def testJoinTemporalTableOnMultiKeyFieldsWithConstantKey(): Unit = { + val streamTable = env.fromCollection(data) + .toTable(tEnv, 'id, 'len, 'content, 'proctime) + tEnv.registerTable("T", streamTable) + + tEnv.registerTableSource("userTable", userTableSourceWith2Keys) + + val sql = "SELECT T.id, T.len, D.name FROM T JOIN userTable " + + "for system_time as of T.proctime AS D ON T.content = D.name AND 3 = D.id" + + val sink = new TestingAppendSink + tEnv.sqlQuery(sql).toAppendStream[Row].addSink(sink) + env.execute() + + val expected = Seq("3,15,Fabian") + assertEquals(expected.sorted, sink.getAppendResults.sorted) + assertEquals(0, userTableSourceWith2Keys.getResourceCounter) + } + + @Test + def testJoinTemporalTableOnMultiKeyFieldsWithStringConstantKey(): Unit = { + val streamTable = env.fromCollection(data) + .toTable(tEnv, 'id, 'len, 'content, 'proctime) + tEnv.registerTable("T", streamTable) + + tEnv.registerTableSource("userTable", userTableSourceWith2Keys) + + val sql = "SELECT T.id, T.len, D.name FROM T JOIN userTable " + + "for system_time as of T.proctime AS D ON D.name = 'Fabian' AND T.id = D.id" + + val sink = new TestingAppendSink + tEnv.sqlQuery(sql).toAppendStream[Row].addSink(sink) + env.execute() + + val expected = Seq("3,15,Fabian") + assertEquals(expected.sorted, sink.getAppendResults.sorted) + assertEquals(0, userTableSourceWith2Keys.getResourceCounter) + } + + @Test + def testJoinTemporalTableOnMultiConstantKey(): Unit = { + val streamTable = env.fromCollection(data) + .toTable(tEnv, 'id, 'len, 'content, 'proctime) + tEnv.registerTable("T", streamTable) + + tEnv.registerTableSource("userTable", userTableSourceWith2Keys) + + val sql = "SELECT T.id, T.len, D.name FROM T JOIN userTable " + + "for system_time as of T.proctime AS D ON D.name = 'Fabian' AND 3 = D.id" + + val sink = new TestingAppendSink + tEnv.sqlQuery(sql).toAppendStream[Row].addSink(sink) + env.execute() + + val expected = Seq( + "1,12,Fabian", + "2,15,Fabian", + "3,15,Fabian", + "8,11,Fabian", + "9,12,Fabian" + ) + assertEquals(expected.sorted, sink.getAppendResults.sorted) + assertEquals(0, userTableSourceWith2Keys.getResourceCounter) + } + + @Test + def testLeftJoinTemporalTable(): 
Unit = { + val streamTable = env.fromCollection(data) + .toTable(tEnv, 'id, 'len, 'content, 'proctime) + tEnv.registerTable("T", streamTable) + + tEnv.registerTableSource("userTable", userTableSource) + + val sql = "SELECT T.id, T.len, D.name, D.age FROM T LEFT JOIN userTable " + + "for system_time as of T.proctime AS D ON T.id = D.id" + + val sink = new TestingAppendSink + tEnv.sqlQuery(sql).toAppendStream[Row].addSink(sink) + env.execute() + + val expected = Seq( + "1,12,Julian,11", + "2,15,Jark,22", + "3,15,Fabian,33", + "8,11,null,null", + "9,12,null,null") + assertEquals(expected.sorted, sink.getAppendResults.sorted) + assertEquals(0, userTableSource.getResourceCounter) + } + + @Test + def testLeftJoinTemporalTableOnNullableKey(): Unit = { + + implicit val tpe: TypeInformation[Row] = new RowTypeInfo( + BasicTypeInfo.LONG_TYPE_INFO, + BasicTypeInfo.INT_TYPE_INFO, + BasicTypeInfo.STRING_TYPE_INFO) + val streamTable = env.fromCollection(dataWithNull) + .toTable(tEnv, 'id, 'len, 'content, 'proctime) + tEnv.registerTable("T", streamTable) + + tEnv.registerTableSource("userTable", userTableSource) + + val sql = "SELECT T.id, T.len, D.name FROM T LEFT OUTER JOIN userTable " + + "for system_time as of T.proctime AS D ON T.id = D.id" + + val sink = new TestingAppendSink + tEnv.sqlQuery(sql).toAppendStream[Row].addSink(sink) + env.execute() + + val expected = Seq( + "null,15,null", + "3,15,Fabian", + "null,11,null", + "9,12,null") + assertEquals(expected.sorted, sink.getAppendResults.sorted) + assertEquals(0, userTableSource.getResourceCounter) + } + + @Test + def testLeftJoinTemporalTableOnMultKeyFields(): Unit = { + val streamTable = env.fromCollection(data) + .toTable(tEnv, 'id, 'len, 'content, 'proctime) + tEnv.registerTable("T", streamTable) + + tEnv.registerTableSource("userTable", userTableSource) + + val sql = "SELECT T.id, T.len, D.name, D.age FROM T LEFT JOIN userTable " + + "for system_time as of T.proctime AS D ON T.id = D.id and T.content = D.name" + + val sink = new TestingAppendSink + tEnv.sqlQuery(sql).toAppendStream[Row].addSink(sink) + env.execute() + + val expected = Seq( + "1,12,Julian,11", + "2,15,null,null", + "3,15,Fabian,33", + "8,11,null,null", + "9,12,null,null") + assertEquals(expected.sorted, sink.getAppendResults.sorted) + assertEquals(0, userTableSource.getResourceCounter) + } + +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/InMemoryLookupableTableSource.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/InMemoryLookupableTableSource.scala new file mode 100644 index 0000000000000..21cf7f82e6b51 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/InMemoryLookupableTableSource.scala @@ -0,0 +1,386 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.utils + +import org.apache.flink.annotation.VisibleForTesting +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.java.typeutils.RowTypeInfo +import org.apache.flink.streaming.api.datastream.DataStream +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment +import org.apache.flink.streaming.api.functions.async.ResultFuture +import org.apache.flink.table.api.TableSchema +import org.apache.flink.table.functions.{AsyncTableFunction, FunctionContext, TableFunction} +import org.apache.flink.table.runtime.utils.InMemoryLookupableTableSource.{InMemoryAsyncLookupFunction, InMemoryLookupFunction} +import org.apache.flink.table.sources.TableIndex.IndexType +import org.apache.flink.table.sources._ +import org.apache.flink.types.Row +import org.apache.flink.util.Preconditions + +import java.util +import java.util.Collections +import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.{CompletableFuture, ExecutorService, Executors} +import java.util.function.{Consumer, Supplier} + +import scala.annotation.varargs +import scala.collection.JavaConverters._ +import scala.collection.mutable + +/** + * A [[LookupableTableSource]] which stores table in memory, this is mainly used for testing. + */ +class InMemoryLookupableTableSource( + fieldNames: Array[String], + fieldTypes: Array[TypeInformation[_]], + data: List[Row], + primaryKey: Option[Array[String]], + tableIndexes: Array[TableIndex], + lookupConfig: LookupConfig) + extends LookupableTableSource[Row] + with StreamTableSource[Row] + with BatchTableSource[Row] + with DefinedPrimaryKey + with DefinedIndexes { + + lazy val uniqueKeys: Array[Array[String]] = { + val keys = new mutable.ArrayBuffer[Array[String]]() + if (getPrimaryKeyColumns != null) { + keys += getPrimaryKeyColumns.asScala.toArray + } + getIndexes.asScala + .filter(_.getIndexType == IndexType.UNIQUE) + .foreach(keys += _.getIndexedColumns.asScala.toArray) + keys.toArray + } + + val resourceCounter = new AtomicInteger(0) + + override def getLookupFunction(lookupKeys: Array[String]): TableFunction[Row] = { + new InMemoryLookupFunction(convertDataToMap(lookupKeys), resourceCounter) + } + + override def getAsyncLookupFunction(lookupKeys: Array[String]): AsyncTableFunction[Row] = { + new InMemoryAsyncLookupFunction(convertDataToMap(lookupKeys), resourceCounter) + } + + private def convertDataToMap(lookupKeys: Array[String]): Map[Row, List[Row]] = { + val isUniqueKey = uniqueKeys.contains(lookupKeys) + val lookupFieldIndexes = lookupKeys.map(fieldNames.indexOf(_)) + val map = mutable.HashMap[Row, List[Row]]() + if (isUniqueKey) { + data.foreach { row => + val key = Row.of(lookupFieldIndexes.map(row.getField): _*) + val oldValue = map.put(key, List(row)) + if (oldValue.isDefined) { + throw new IllegalStateException("data contains duplicate keys.") + } + } + } else { + data.foreach { row => + val key = Row.of(lookupFieldIndexes.map(row.getField): _*) + val oldValue = map.get(key) + if (oldValue.isDefined) { + map.put(key, oldValue.get ++ List(row)) + } else { + map.put(key, List(row)) + } + } + } + map.toMap + } + + override def getLookupConfig: LookupConfig = lookupConfig + + override def getReturnType: TypeInformation[Row] = new RowTypeInfo(fieldTypes, fieldNames) + + override def getTableSchema: TableSchema = new TableSchema(fieldNames, fieldTypes) + + override def 
getPrimaryKeyColumns: util.List[String] = primaryKey match { + case Some(pk) => pk.toList.asJava + case None => null // return null to indicate no primary key is defined. + } + + override def getIndexes: util.Collection[TableIndex] = tableIndexes.toList.asJava + + @VisibleForTesting + def getResourceCounter: Int = resourceCounter.get() + + override def getDataStream(execEnv: StreamExecutionEnvironment): DataStream[Row] = { + throw new UnsupportedOperationException("This should never be called.") + } + + override def getBoundedStream(streamEnv: StreamExecutionEnvironment): DataStream[Row] = { + throw new UnsupportedOperationException("This should never be called.") + } +} + +object InMemoryLookupableTableSource { + + /** + * Return a new builder that builds a [[InMemoryLookupableTableSource]]. + * + * For example: + * + * {{{ + * val data = ( + * (11, 1L, "Julian"), + * (22, 2L, "Jark"), + * (33, 3L, "Fabian")) + * + * val source = InMemoryLookupableTableSource.builder() + * .data(data) + * .field("age", Types.INT) + * .field("id", Types.LONG) + * .field("name", Types.STRING) + * .primaryKey("id") + * .addNormalIndex("name") + * .enableAsync() + * .build() + * }}} + * + * @return a new builder to build a [[InMemoryLookupableTableSource]] + */ + def builder(): Builder = new Builder + + + /** + * A builder for creating [[InMemoryLookupableTableSource]] instances. + * + * For example: + * + * {{{ + * val data = ( + * (11, 1L, "Julian"), + * (22, 2L, "Jark"), + * (33, 3L, "Fabian")) + * + * val source = InMemoryLookupableTableSource.builder() + * .data(data) + * .field("age", Types.INT) + * .field("id", Types.LONG) + * .field("name", Types.STRING) + * .primaryKey("id") + * .addNormalIndex("name") + * .enableAsync() + * .build() + * }}} + */ + class Builder { + private val schema = new mutable.LinkedHashMap[String, TypeInformation[_]]() + private val tableIndexes = new mutable.ArrayBuffer[TableIndex]() + private var primaryKey: Option[Array[String]] = None + private var data: List[Product] = _ + private val lookupConfigBuilder: LookupConfig.Builder = LookupConfig.builder() + + /** + * Sets table data for the table source. + */ + def data(data: List[Product]): Builder = { + this.data = data + this + } + + /** + * Adds a field with the field name and the type information. Required. + * This method can be called multiple times. The call order of this method defines + * also the order of the fields in a row. + * + * @param fieldName the field name + * @param fieldType the type information of the field + */ + def field(fieldName: String, fieldType: TypeInformation[_]): Builder = { + if (schema.contains(fieldName)) { + throw new IllegalArgumentException(s"Duplicate field name $fieldName.") + } + schema += (fieldName -> fieldType) + this + } + + /** + * Sets primary key for the table source. 
+ */ + def primaryKey(fields: String*): Builder = { + if (fields.isEmpty) { + throw new IllegalArgumentException("fields should not be empty.") + } + if (primaryKey != null && primaryKey.isDefined) { + throw new IllegalArgumentException("primary key has been set.") + } + this.primaryKey = Some(fields.toArray) + this + } + + /** + * Adds a normal [[TableIndex]] for the table source + */ + def addNormalIndex(fields: String*): Builder = { + if (fields.isEmpty) { + throw new IllegalArgumentException("fields should not be empty.") + } + val index = TableIndex.builder() + .normalIndex() + .indexedColumns(fields: _*) + .build() + tableIndexes += index + this + } + + /** + * Adds an unique [[TableIndex]] for the table source + */ + def addUniqueIndex(fields: String*): Builder = { + if (fields.isEmpty) { + throw new IllegalArgumentException("fields should not be empty.") + } + val index = TableIndex.builder() + .uniqueIndex() + .indexedColumns(fields: _*) + .build() + tableIndexes += index + this + } + + /** + * Enables async lookup for the table source + */ + def enableAsync(): Builder = { + lookupConfigBuilder.setAsyncEnabled(true) + this + } + + /** + * Sets async buffer capacity. + */ + def asyncBufferCapacity(capacity: Int): Builder = { + lookupConfigBuilder.setAsyncBufferCapacity(capacity) + this + } + + /** + * Sets async time out milli-second. + */ + def asyncTimeoutMs(ms: Long): Builder = { + lookupConfigBuilder.setAsyncTimeoutMs(ms) + this + } + + /** + * Apply the current values and constructs a newly-created [[InMemoryLookupableTableSource]]. + * + * @return a newly-created [[InMemoryLookupableTableSource]]. + */ + def build(): InMemoryLookupableTableSource = { + val fieldNames = schema.keys.toArray + val fieldTypes = schema.values.toArray + Preconditions.checkNotNull(data) + // convert + val rowData = data.map { entry => + Row.of((0 until entry.productArity).map(entry.productElement(_).asInstanceOf[Object]): _*) + } + new InMemoryLookupableTableSource( + fieldNames, + fieldTypes, + rowData, + primaryKey, + tableIndexes.toArray, + lookupConfigBuilder.build() + ) + } + } + + /** + * A lookup function which find matched rows with the given fields. + */ + private class InMemoryLookupFunction( + data: Map[Row, List[Row]], + resourceCounter: AtomicInteger) + extends TableFunction[Row] { + + override def open(context: FunctionContext): Unit = { + resourceCounter.incrementAndGet() + } + + @varargs + def eval(inputs: AnyRef*): Unit = { + val key = Row.of(inputs: _*) + data.get(key) match { + case Some(list) => list.foreach(result => collect(result)) + case None => // do nothing + } + } + + override def close(): Unit = { + resourceCounter.decrementAndGet() + } + } + + /** + * An async lookup function which find matched rows with the given fields. 
+ */ + private class InMemoryAsyncLookupFunction( + data: Map[Row, List[Row]], + resourceCounter: AtomicInteger, + delayedReturn: Int = 0) + extends AsyncTableFunction[Row] { + + @transient + var executor: ExecutorService = _ + + override def open(context: FunctionContext): Unit = { + resourceCounter.incrementAndGet() + executor = Executors.newSingleThreadExecutor() + } + + @varargs + def eval(resultFuture: ResultFuture[Row], inputs: AnyRef*): Unit = { + CompletableFuture + .supplyAsync(new CollectionSupplier(data, Row.of(inputs: _*)), executor) + .thenAccept(new CollectionConsumer(resultFuture)) + } + + override def close(): Unit = { + resourceCounter.decrementAndGet() + if (null != executor && !executor.isShutdown) { + executor.shutdown() + } + } + + private class CollectionSupplier(data: Map[Row, List[Row]], key: Row) + extends Supplier[util.Collection[Row]] { + + override def get(): util.Collection[Row] = { + val list = data.get(key) + if (list.isDefined && list.get.nonEmpty) { + list.get.asJavaCollection + } else { + Collections.emptyList() + } + } + } + + private class CollectionConsumer(resultFuture: ResultFuture[Row]) + extends Consumer[util.Collection[Row]] { + + override def accept(results: util.Collection[Row]): Unit = { + resultFuture.complete(results) + } + } + } + +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/UserDefinedFunctionTestUtils.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/UserDefinedFunctionTestUtils.scala index 117c725b80561..f40b9f417d64d 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/UserDefinedFunctionTestUtils.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/UserDefinedFunctionTestUtils.scala @@ -23,7 +23,7 @@ import org.apache.flink.api.java.tuple.{Tuple1, Tuple2} import org.apache.flink.api.scala.ExecutionEnvironment import org.apache.flink.configuration.Configuration import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment -import org.apache.flink.table.functions.{AggregateFunction, ScalarFunction} +import org.apache.flink.table.functions.{AggregateFunction, FunctionContext, ScalarFunction} import com.google.common.base.Charsets import com.google.common.io.Files import org.apache.flink.api.common.typeinfo.{SqlTimeTypeInfo, TypeInformation} @@ -34,6 +34,7 @@ import org.apache.flink.types.Row import java.io.File import java.util +import java.util.concurrent.atomic.AtomicInteger import scala.annotation.varargs @@ -259,6 +260,59 @@ object UserDefinedFunctionTestUtils { } } + object TestWrapperUdf extends ScalarFunction { + def eval(id: Int): Int = { + id + } + + def eval(id: String): String = { + id + } + } + + class TestAddWithOpen extends ScalarFunction { + + var isOpened: Boolean = false + + override def open(context: FunctionContext): Unit = { + super.open(context) + isOpened = true + TestAddWithOpen.aliveCounter.incrementAndGet() + } + + def eval(a: Long, b: Long): Long = { + if (!isOpened) { + throw new IllegalStateException("Open method is not called.") + } + a + b + } + + def eval(a: Long, b: Int): Long = { + eval(a, b.asInstanceOf[Long]) + } + + override def close(): Unit = { + TestAddWithOpen.aliveCounter.decrementAndGet() + } + } + + object TestAddWithOpen { + /** A thread-safe counter to record how many alive TestAddWithOpen UDFs */ + val aliveCounter = new AtomicInteger(0) + } + + object TestMod extends ScalarFunction 
{ + def eval(src: Long, mod: Int): Long = { + src % mod + } + } + + object TestExceptionThrown extends ScalarFunction { + def eval(src: String): Int = { + throw new NumberFormatException("Cannot parse this input.") + } + } + // ------------------------------------------------------------------------------------ // POJOs // ------------------------------------------------------------------------------------ diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/util/TableTestBase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/util/TableTestBase.scala index 502767b15ef4c..f766135e10ff8 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/util/TableTestBase.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/util/TableTestBase.scala @@ -36,10 +36,10 @@ import org.apache.flink.table.plan.optimize.program.{FlinkBatchProgram, FlinkStr import org.apache.flink.table.plan.util.{ExecNodePlanDumper, FlinkRelOptUtil} import org.apache.flink.table.sources.{BatchTableSource, StreamTableSource} import org.apache.flink.table.typeutils.BaseRowTypeInfo - import org.apache.calcite.rel.RelNode import org.apache.calcite.sql.SqlExplainLevel import org.apache.commons.lang3.SystemUtils +import org.apache.flink.table.runtime.utils.BatchTableEnvUtil import org.junit.Assert.{assertEquals, assertTrue} import org.junit.Rule import org.junit.rules.{ExpectedException, TestName} @@ -457,10 +457,16 @@ case class BatchTableTestUtil(test: TableTestBase) extends TableTestUtil(test) { override def getTableEnv: TableEnvironment = tableEnv - // TODO implements this method when a DataStream could be converted into a Table override def addDataStream[T: TypeInformation]( name: String, fields: Symbol*): Table = { - throw new TableException("Implements this") + val typeInfo = implicitly[TypeInformation[T]] + BatchTableEnvUtil.registerCollection( + tableEnv, + name, + Seq(), + typeInfo, + fields.map(_.name).mkString(", ")) + tableEnv.scan(name) } def buildBatchProgram(firstProgramNameToRemove: String): Unit = { diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/generated/GeneratedClass.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/generated/GeneratedClass.java index 903943a83d950..92af598a85a2c 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/generated/GeneratedClass.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/generated/GeneratedClass.java @@ -69,7 +69,10 @@ public T newInstance(ClassLoader classLoader, Object... args) { } } - private Class compile(ClassLoader classLoader) { + /** + * Compiles the generated code, the compiled class will be cached in the {@link GeneratedClass}. 
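Making compile public enables a fail-fast pattern: compile the generated code eagerly and keep instantiating from the cached class afterwards. A minimal sketch of that pattern, assuming newInstance reuses the class cached by compile (the variable names mirror AsyncLookupJoinRunner#open further below; this is a fragment of an open()-style method, not part of the patch):

    // Sketch only; assumes newInstance reuses the Class cached by compile.
    ClassLoader cl = getRuntimeContext().getUserCodeClassLoader();
    // fail fast if the generated code is corrupt; the compiled Class is cached inside the GeneratedClass
    generatedResultFuture.compile(cl);
    // later, per-use instantiation of the generated ResultFuture
    TableFunctionResultFuture<BaseRow> resultFuture = generatedResultFuture.newInstance(cl);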
+ */ + public Class compile(ClassLoader classLoader) { if (compiledClass == null) { // cache the compiled class compiledClass = CompileUtils.compile(classLoader, className, code); diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/generated/GeneratedCollector.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/generated/GeneratedCollector.java index ce082ba655ab7..3e36d124c449c 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/generated/GeneratedCollector.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/generated/GeneratedCollector.java @@ -25,7 +25,7 @@ * * @param type of collector */ -public final class GeneratedCollector> extends GeneratedClass { +public class GeneratedCollector> extends GeneratedClass { private static final long serialVersionUID = -7355875544905245676L; diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/generated/GeneratedFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/generated/GeneratedFunction.java index e43da42ef7085..69163f7db37a7 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/generated/GeneratedFunction.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/generated/GeneratedFunction.java @@ -25,7 +25,7 @@ * * @param type of Function */ -public final class GeneratedFunction extends GeneratedClass { +public class GeneratedFunction extends GeneratedClass { private static final long serialVersionUID = -7355875544905245676L; diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/generated/GeneratedHashFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/generated/GeneratedHashFunction.java index 204907491cb12..914d0bcfdd087 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/generated/GeneratedHashFunction.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/generated/GeneratedHashFunction.java @@ -21,7 +21,7 @@ /** * Describes a generated {@link HashFunction}. 
*/ -public final class GeneratedHashFunction extends GeneratedClass { +public class GeneratedHashFunction extends GeneratedClass { private static final long serialVersionUID = 1L; diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/generated/GeneratedInput.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/generated/GeneratedInput.java index 9331bb3a4d960..96ddc4a2e36eb 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/generated/GeneratedInput.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/generated/GeneratedInput.java @@ -25,7 +25,7 @@ * * @param type of Function */ -public final class GeneratedInput> extends GeneratedClass { +public class GeneratedInput> extends GeneratedClass { private static final long serialVersionUID = -7355875544905245676L; diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/generated/GeneratedNamespaceAggsHandleFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/generated/GeneratedNamespaceAggsHandleFunction.java index fef65e9c0803b..5a9ef8beed20c 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/generated/GeneratedNamespaceAggsHandleFunction.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/generated/GeneratedNamespaceAggsHandleFunction.java @@ -21,7 +21,7 @@ /** * Describes a generated {@link NamespaceAggsHandleFunction}. */ -public final class GeneratedNamespaceAggsHandleFunction +public class GeneratedNamespaceAggsHandleFunction extends GeneratedClass> { private static final long serialVersionUID = 1L; diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/generated/GeneratedOperator.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/generated/GeneratedOperator.java index e2209a2a5f882..e0316ed42c0e1 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/generated/GeneratedOperator.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/generated/GeneratedOperator.java @@ -25,7 +25,7 @@ * * @param type of StreamOperator */ -public final class GeneratedOperator> extends GeneratedClass { +public class GeneratedOperator> extends GeneratedClass { private static final long serialVersionUID = -7355875544905245676L; diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/generated/GeneratedResultFuture.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/generated/GeneratedResultFuture.java new file mode 100644 index 0000000000000..e79142688dc64 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/generated/GeneratedResultFuture.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.generated; + +import org.apache.flink.streaming.api.functions.async.ResultFuture; + +/** + * Describes a generated {@link ResultFuture}. + * + * @param type of ResultFuture + */ +public class GeneratedResultFuture> extends GeneratedClass { + + private static final long serialVersionUID = -7355875544905245676L; + + /** + * Creates a GeneratedResultFuture. + * + * @param className class name of the generated ResultFuture. + * @param code code of the generated ResultFuture. + * @param references referenced objects of the generated ResultFuture. + */ + public GeneratedResultFuture(String className, String code, Object[] references) { + super(className, code, references); + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/collector/TableFunctionCollector.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/collector/TableFunctionCollector.java new file mode 100644 index 0000000000000..a96f40d10ef32 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/collector/TableFunctionCollector.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.collector; + +import org.apache.flink.api.common.functions.AbstractRichFunction; +import org.apache.flink.table.functions.TableFunction; +import org.apache.flink.util.Collector; + +/** + * The basic implementation of collector for {@link TableFunction}. + */ +public abstract class TableFunctionCollector extends AbstractRichFunction implements Collector { + + private static final long serialVersionUID = 1L; + + private Object input; + private Collector collector; + private boolean collected; + + /** + * Sets the input row from left table, + * which will be used to cross join with the result of table function. + */ + public void setInput(Object input) { + this.input = input; + } + + /** + * Gets the input value from left table, + * which will be used to cross join with the result of table function. + */ + public Object getInput() { + return input; + } + + /** + * Sets the current collector, which used to emit the final row. 
+ */ + public void setCollector(Collector collector) { + this.collector = collector; + } + + /** + * Resets the flag to indicate whether [[collect(T)]] has been called. + */ + public void reset() { + this.collected = false; + if (collector instanceof TableFunctionCollector) { + ((TableFunctionCollector) collector).reset(); + } + } + + /** + * Output final result of this UDTF to downstreams. + */ + @SuppressWarnings("unchecked") + public void outputResult(Object result) { + this.collected = true; + this.collector.collect(result); + } + + /** + * Whether {@link #collect(Object)} has been called. + * + * @return True if {@link #collect(Object)} has been called. + */ + public boolean isCollected() { + return collected; + } + + public void close() { + this.collector.close(); + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/collector/TableFunctionResultFuture.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/collector/TableFunctionResultFuture.java new file mode 100644 index 0000000000000..a472f9ceb5b3d --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/collector/TableFunctionResultFuture.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.collector; + +import org.apache.flink.api.common.functions.AbstractRichFunction; +import org.apache.flink.streaming.api.functions.async.ResultFuture; + +/** + * The basic implementation of collector for {@link ResultFuture} in table joining. + */ +public abstract class TableFunctionResultFuture extends AbstractRichFunction implements ResultFuture { + + private static final long serialVersionUID = 1L; + + private Object input; + private ResultFuture resultFuture; + + /** + * Sets the input row from left table, + * which will be used to cross join with the result of right table. + */ + public void setInput(Object input) { + this.input = input; + } + + /** + * Gets the input value from left table, + * which will be used to cross join with the result of right table. + */ + public Object getInput() { + return input; + } + + /** + * Sets the current collector, which used to emit the final row. + */ + public void setResultFuture(ResultFuture resultFuture) { + this.resultFuture = resultFuture; + } + + /** + * Gets the internal collector which used to emit the final row. 
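To make the collector contract above concrete, here is a hedged, hand-written sketch of roughly what a planner-generated condition collector does; the field index and the condition are illustrative assumptions, not part of this patch:

    // Illustrative only: applies a join condition to the looked-up row and, on a match,
    // emits the combined row through the collector the runner registered via setCollector(...).
    public class ExampleConditionCollector extends TableFunctionCollector<BaseRow> {
        private static final long serialVersionUID = 1L;

        @Override
        public void collect(BaseRow lookedUpRow) {
            BaseRow left = (BaseRow) getInput();      // probe-side row set by the runner via setInput(...)
            if (lookedUpRow.getInt(0) > 10) {         // hypothetical join condition on the lookup result
                // marks this input as matched (isCollected() becomes true) and forwards the joined row
                outputResult(new JoinedRow(left, lookedUpRow));
            }
        }
    }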
+ */ + public ResultFuture getResultFuture() { + return this.resultFuture; + } + + @Override + public void completeExceptionally(Throwable error) { + this.resultFuture.completeExceptionally(error); + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/join/lookup/AsyncLookupJoinRunner.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/join/lookup/AsyncLookupJoinRunner.java new file mode 100644 index 0000000000000..4a15671e15691 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/join/lookup/AsyncLookupJoinRunner.java @@ -0,0 +1,281 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.join.lookup; + +import org.apache.flink.api.common.functions.util.FunctionUtils; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.streaming.api.functions.async.AsyncFunction; +import org.apache.flink.streaming.api.functions.async.ResultFuture; +import org.apache.flink.streaming.api.functions.async.RichAsyncFunction; +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.dataformat.DataFormatConverters; +import org.apache.flink.table.dataformat.DataFormatConverters.RowConverter; +import org.apache.flink.table.dataformat.GenericRow; +import org.apache.flink.table.dataformat.JoinedRow; +import org.apache.flink.table.generated.GeneratedFunction; +import org.apache.flink.table.generated.GeneratedResultFuture; +import org.apache.flink.table.runtime.collector.TableFunctionResultFuture; +import org.apache.flink.table.typeutils.BaseRowTypeInfo; +import org.apache.flink.types.Row; + +import javax.annotation.Nullable; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; + +/** + * The async join runner to lookup the dimension table. + */ +public class AsyncLookupJoinRunner extends RichAsyncFunction { + private static final long serialVersionUID = -6664660022391632480L; + + private final GeneratedFunction> generatedFetcher; + private final GeneratedResultFuture> generatedResultFuture; + private final boolean isLeftOuterJoin; + private final int asyncBufferCapacity; + private final TypeInformation fetcherReturnType; + private final BaseRowTypeInfo rightRowTypeInfo; + + private transient AsyncFunction fetcher; + + /** + * Buffers {@link ResultFuture} to avoid newInstance cost when processing elements every time. 
+ * We use {@link BlockingQueue} to make sure the head {@link ResultFuture}s are available. + */ + private transient BlockingQueue resultFutureBuffer; + + public AsyncLookupJoinRunner( + GeneratedFunction> generatedFetcher, + GeneratedResultFuture> generatedResultFuture, + TypeInformation fetcherReturnType, + BaseRowTypeInfo rightRowTypeInfo, + boolean isLeftOuterJoin, + int asyncBufferCapacity) { + this.generatedFetcher = generatedFetcher; + this.generatedResultFuture = generatedResultFuture; + this.isLeftOuterJoin = isLeftOuterJoin; + this.asyncBufferCapacity = asyncBufferCapacity; + this.fetcherReturnType = fetcherReturnType; + this.rightRowTypeInfo = rightRowTypeInfo; + } + + @Override + public void open(Configuration parameters) throws Exception { + super.open(parameters); + this.fetcher = generatedFetcher.newInstance(getRuntimeContext().getUserCodeClassLoader()); + FunctionUtils.setFunctionRuntimeContext(fetcher, getRuntimeContext()); + FunctionUtils.openFunction(fetcher, parameters); + + // try to compile the generated ResultFuture, fail fast if the code is corrupt. + generatedResultFuture.compile(getRuntimeContext().getUserCodeClassLoader()); + + // row converter is stateless which is thread-safe + RowConverter rowConverter; + if (fetcherReturnType instanceof RowTypeInfo) { + rowConverter = (RowConverter) DataFormatConverters.getConverterForTypeInfo(fetcherReturnType); + } else if (fetcherReturnType instanceof BaseRowTypeInfo) { + rowConverter = null; + } else { + throw new IllegalStateException("This should never happen, " + + "currently fetcherReturnType can only be BaseRowTypeInfo or RowTypeInfo"); + } + + // asyncBufferCapacity + 1 as the queue size in order to avoid + // blocking on the queue when taking a collector. + this.resultFutureBuffer = new ArrayBlockingQueue<>(asyncBufferCapacity + 1); + for (int i = 0; i < asyncBufferCapacity + 1; i++) { + JoinedRowResultFuture rf = new JoinedRowResultFuture( + resultFutureBuffer, + createFetcherResultFuture(parameters), + rowConverter, + isLeftOuterJoin, + rightRowTypeInfo.getArity()); + // add will throw exception immediately if the queue is full which should never happen + resultFutureBuffer.add(rf); + } + } + + @Override + public void asyncInvoke(BaseRow input, ResultFuture resultFuture) throws Exception { + JoinedRowResultFuture outResultFuture = resultFutureBuffer.take(); + // the input row is copied when object reuse in AsyncWaitOperator + outResultFuture.reset(input, resultFuture); + + // fetcher has copied the input field when object reuse is enabled + fetcher.asyncInvoke(input, outResultFuture); + } + + public TableFunctionResultFuture createFetcherResultFuture(Configuration parameters) throws Exception { + TableFunctionResultFuture resultFuture = generatedResultFuture.newInstance( + getRuntimeContext().getUserCodeClassLoader()); + FunctionUtils.setFunctionRuntimeContext(resultFuture, getRuntimeContext()); + FunctionUtils.openFunction(resultFuture, parameters); + return resultFuture; + } + + @Override + public void close() throws Exception { + super.close(); + if (fetcher != null) { + FunctionUtils.closeFunction(fetcher); + } + for (JoinedRowResultFuture rf : resultFutureBuffer) { + rf.close(); + } + } + + /** + * The {@link JoinedRowResultFuture} is used to combine left {@link BaseRow} and + * right {@link BaseRow} into {@link JoinedRow}. + * + *

<p>There are 3 phases in this collector. + * + * <ol> + * <li>accept lookup function return result and convert it into BaseRow, call it right result</li> + * <li>project & filter the right result if there is a calc on the temporal table, + * see {@link AsyncLookupJoinWithCalcRunner#createFetcherResultFuture(Configuration)}</li> + * <li>filter the result if a join condition exists, + * see {@link AsyncLookupJoinRunner#createFetcherResultFuture(Configuration)}</li> + * <li>combine left input and the right result into a JoinedRow, call it join result</li> + * </ol> + * + * <p>
TODO: code generate a whole JoinedRowResultFuture in the future + */ + private static final class JoinedRowResultFuture implements ResultFuture { + + private final BlockingQueue resultFutureBuffer; + private final TableFunctionResultFuture joinConditionResultFuture; + private final RowConverter rowConverter; + private final boolean isLeftOuterJoin; + + private final DelegateResultFuture delegate; + private final GenericRow nullRow; + + private BaseRow leftRow; + private ResultFuture realOutput; + + private JoinedRowResultFuture( + BlockingQueue resultFutureBuffer, + TableFunctionResultFuture joinConditionResultFuture, + @Nullable RowConverter rowConverter, + boolean isLeftOuterJoin, + int rightArity) { + this.resultFutureBuffer = resultFutureBuffer; + this.joinConditionResultFuture = joinConditionResultFuture; + this.rowConverter = rowConverter; + this.isLeftOuterJoin = isLeftOuterJoin; + this.delegate = new DelegateResultFuture(); + this.nullRow = new GenericRow(rightArity); + } + + public void reset(BaseRow row, ResultFuture realOutput) { + this.realOutput = realOutput; + this.leftRow = row; + joinConditionResultFuture.setInput(row); + joinConditionResultFuture.setResultFuture(delegate); + delegate.reset(); + } + + @Override + public void complete(Collection result) { + Collection baseRows; + if (rowConverter == null) { + // result is BaseRow Collection + //noinspection unchecked + baseRows = (Collection) result; + } else { + baseRows = new ArrayList<>(result.size()); + for (Object element : result) { + Row row = (Row) element; + baseRows.add(rowConverter.toInternal(row)); + } + } + + // call condition collector first, + // the filtered result will be routed to the delegateCollector + try { + joinConditionResultFuture.complete(baseRows); + } catch (Throwable t) { + // we should catch the exception here to let the framework know + completeExceptionally(t); + return; + } + + Collection rightRows = delegate.collection; + if (rightRows == null || rightRows.isEmpty()) { + if (isLeftOuterJoin) { + BaseRow outRow = new JoinedRow(leftRow, nullRow); + outRow.setHeader(leftRow.getHeader()); + realOutput.complete(Collections.singleton(outRow)); + } else { + realOutput.complete(Collections.emptyList()); + } + } else { + List outRows = new ArrayList<>(); + for (BaseRow rightRow : rightRows) { + BaseRow outRow = new JoinedRow(leftRow, rightRow); + outRow.setHeader(leftRow.getHeader()); + outRows.add(outRow); + } + realOutput.complete(outRows); + } + try { + // put this collector to the queue to avoid this collector is used + // again before outRows in the collector is not consumed. 
+ resultFutureBuffer.put(this); + } catch (InterruptedException e) { + completeExceptionally(e); + } + } + + @Override + public void completeExceptionally(Throwable error) { + realOutput.completeExceptionally(error); + } + + public void close() throws Exception { + joinConditionResultFuture.close(); + } + + private final class DelegateResultFuture implements ResultFuture { + + private Collection collection; + + public void reset() { + this.collection = null; + } + + @Override + public void complete(Collection result) { + this.collection = result; + } + + @Override + public void completeExceptionally(Throwable error) { + JoinedRowResultFuture.this.completeExceptionally(error); + } + } + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/join/lookup/AsyncLookupJoinWithCalcRunner.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/join/lookup/AsyncLookupJoinWithCalcRunner.java new file mode 100644 index 0000000000000..925d52e408e6f --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/join/lookup/AsyncLookupJoinWithCalcRunner.java @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.join.lookup; + +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.common.functions.util.FunctionUtils; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.streaming.api.functions.async.AsyncFunction; +import org.apache.flink.streaming.api.functions.async.ResultFuture; +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.generated.GeneratedFunction; +import org.apache.flink.table.generated.GeneratedResultFuture; +import org.apache.flink.table.runtime.collector.TableFunctionResultFuture; +import org.apache.flink.table.typeutils.BaseRowTypeInfo; +import org.apache.flink.util.Collector; + +import java.util.ArrayList; +import java.util.Collection; + +/** + * The async join runner with an additional calculate function on the dimension table. 
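The buffering described for AsyncLookupJoinRunner above is essentially a small object pool; a condensed sketch of the pattern (identifiers such as createPooledResultFuture are illustrative, not part of this patch):

    // capacity + 1 pooled futures so take() never blocks while the operator's in-flight buffer is full
    BlockingQueue<JoinedRowResultFuture> pool = new ArrayBlockingQueue<>(capacity + 1);
    for (int i = 0; i <= capacity; i++) {
        pool.add(createPooledResultFuture());   // pre-allocated once in open(), instead of per record
    }
    // per input element in asyncInvoke(...):
    JoinedRowResultFuture rf = pool.take();     // reuse a pooled future
    rf.reset(inputRow, outerResultFuture);      // bind it to this element and the real output future
    fetcher.asyncInvoke(inputRow, rf);          // the future re-adds itself to the pool after complete(...)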
+ */ +public class AsyncLookupJoinWithCalcRunner extends AsyncLookupJoinRunner { + + private static final long serialVersionUID = 8758670006385551407L; + + private final GeneratedFunction> generatedCalc; + private final BaseRowTypeInfo rightRowTypeInfo; + private transient TypeSerializer rightSerializer; + + public AsyncLookupJoinWithCalcRunner( + GeneratedFunction> generatedFetcher, + GeneratedFunction> generatedCalc, + GeneratedResultFuture> generatedResultFuture, + TypeInformation fetcherReturnType, + BaseRowTypeInfo rightRowTypeInfo, + boolean isLeftOuterJoin, + int asyncBufferCapacity) { + super(generatedFetcher, generatedResultFuture, fetcherReturnType, + rightRowTypeInfo, isLeftOuterJoin, asyncBufferCapacity); + this.rightRowTypeInfo = rightRowTypeInfo; + this.generatedCalc = generatedCalc; + } + + @Override + public void open(Configuration parameters) throws Exception { + super.open(parameters); + // try to compile the generated ResultFuture, fail fast if the code is corrupt. + generatedCalc.compile(getRuntimeContext().getUserCodeClassLoader()); + rightSerializer = rightRowTypeInfo.createSerializer(getRuntimeContext().getExecutionConfig()); + } + + @Override + public TableFunctionResultFuture createFetcherResultFuture(Configuration parameters) throws Exception { + TableFunctionResultFuture joinConditionCollector = super.createFetcherResultFuture(parameters); + FlatMapFunction calc = generatedCalc.newInstance(getRuntimeContext().getUserCodeClassLoader()); + FunctionUtils.setFunctionRuntimeContext(calc, getRuntimeContext()); + FunctionUtils.openFunction(calc, parameters); + return new TemporalTableCalcResultFuture(calc, joinConditionCollector); + } + + @Override + public void close() throws Exception { + super.close(); + } + + private class TemporalTableCalcResultFuture extends TableFunctionResultFuture { + + private static final long serialVersionUID = -6360673852888872924L; + + private final FlatMapFunction calc; + private final TableFunctionResultFuture joinConditionResultFuture; + private final CalcCollectionCollector calcCollector = new CalcCollectionCollector(); + + private TemporalTableCalcResultFuture( + FlatMapFunction calc, + TableFunctionResultFuture joinConditionResultFuture) { + this.calc = calc; + this.joinConditionResultFuture = joinConditionResultFuture; + } + + @Override + public void setInput(Object input) { + joinConditionResultFuture.setInput(input); + calcCollector.reset(); + } + + @Override + public void setResultFuture(ResultFuture resultFuture) { + joinConditionResultFuture.setResultFuture(resultFuture); + } + + @Override + public void complete(Collection result) { + if (result == null || result.size() == 0) { + joinConditionResultFuture.complete(result); + } else { + for (BaseRow row : result) { + try { + calc.flatMap(row, calcCollector); + } catch (Exception e) { + joinConditionResultFuture.completeExceptionally(e); + } + } + joinConditionResultFuture.complete(calcCollector.collection); + } + } + + @Override + public void close() throws Exception { + super.close(); + joinConditionResultFuture.close(); + FunctionUtils.closeFunction(calc); + } + } + + private class CalcCollectionCollector implements Collector { + + Collection collection; + + public void reset() { + this.collection = new ArrayList<>(); + } + + @Override + public void collect(BaseRow record) { + this.collection.add(rightSerializer.copy(record)); + } + + @Override + public void close() { + } + } +} diff --git 
a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/join/lookup/LookupJoinRunner.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/join/lookup/LookupJoinRunner.java new file mode 100644 index 0000000000000..58bcbe7f8fc36 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/join/lookup/LookupJoinRunner.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.join.lookup; + +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.common.functions.util.FunctionUtils; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.streaming.api.functions.ProcessFunction; +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.dataformat.GenericRow; +import org.apache.flink.table.dataformat.JoinedRow; +import org.apache.flink.table.generated.GeneratedCollector; +import org.apache.flink.table.generated.GeneratedFunction; +import org.apache.flink.table.runtime.collector.TableFunctionCollector; +import org.apache.flink.util.Collector; + +/** + * The join runner to lookup the dimension table. 
+ */ +public class LookupJoinRunner extends ProcessFunction { + private static final long serialVersionUID = -4521543015709964733L; + + private final GeneratedFunction> generatedFetcher; + private final GeneratedCollector> generatedCollector; + private final boolean isLeftOuterJoin; + private final int tableFieldsCount; + + private transient FlatMapFunction fetcher; + protected transient TableFunctionCollector collector; + private transient GenericRow nullRow; + private transient JoinedRow outRow; + + public LookupJoinRunner( + GeneratedFunction> generatedFetcher, + GeneratedCollector> generatedCollector, + boolean isLeftOuterJoin, + int tableFieldsCount) { + this.generatedFetcher = generatedFetcher; + this.generatedCollector = generatedCollector; + this.isLeftOuterJoin = isLeftOuterJoin; + this.tableFieldsCount = tableFieldsCount; + } + + @Override + public void open(Configuration parameters) throws Exception { + super.open(parameters); + this.fetcher = generatedFetcher.newInstance(getRuntimeContext().getUserCodeClassLoader()); + this.collector = generatedCollector.newInstance(getRuntimeContext().getUserCodeClassLoader()); + + FunctionUtils.setFunctionRuntimeContext(fetcher, getRuntimeContext()); + FunctionUtils.setFunctionRuntimeContext(collector, getRuntimeContext()); + FunctionUtils.openFunction(fetcher, parameters); + FunctionUtils.openFunction(collector, parameters); + + this.nullRow = new GenericRow(tableFieldsCount); + this.outRow = new JoinedRow(); + } + + @Override + public void processElement(BaseRow in, Context ctx, Collector out) throws Exception { + collector.setCollector(out); + collector.setInput(in); + collector.reset(); + + // fetcher has copied the input field when object reuse is enabled + fetcher.flatMap(in, getFetcherCollector()); + + if (isLeftOuterJoin && !collector.isCollected()) { + outRow.replace(in, nullRow); + outRow.setHeader(in.getHeader()); + out.collect(outRow); + } + } + + public Collector getFetcherCollector() { + return collector; + } + + @Override + public void close() throws Exception { + super.close(); + if (fetcher != null) { + FunctionUtils.closeFunction(fetcher); + } + if (collector != null) { + FunctionUtils.closeFunction(collector); + } + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/join/lookup/LookupJoinWithCalcRunner.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/join/lookup/LookupJoinWithCalcRunner.java new file mode 100644 index 0000000000000..16702581930a6 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/join/lookup/LookupJoinWithCalcRunner.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.join.lookup; + +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.common.functions.util.FunctionUtils; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.generated.GeneratedCollector; +import org.apache.flink.table.generated.GeneratedFunction; +import org.apache.flink.table.runtime.collector.TableFunctionCollector; +import org.apache.flink.util.Collector; + +/** + * The join runner with an additional calculate function on the dimension table. + */ +public class LookupJoinWithCalcRunner extends LookupJoinRunner { + + private static final long serialVersionUID = 5277183384939603386L; + private final GeneratedFunction> generatedCalc; + + private transient FlatMapFunction calc; + private transient Collector calcCollector; + + public LookupJoinWithCalcRunner( + GeneratedFunction> generatedFetcher, + GeneratedFunction> generatedCalc, + GeneratedCollector> generatedCollector, + boolean isLeftOuterJoin, + int tableFieldsCount) { + super(generatedFetcher, generatedCollector, isLeftOuterJoin, tableFieldsCount); + this.generatedCalc = generatedCalc; + } + + @Override + public void open(Configuration parameters) throws Exception { + super.open(parameters); + this.calc = generatedCalc.newInstance(getRuntimeContext().getUserCodeClassLoader()); + FunctionUtils.setFunctionRuntimeContext(calc, getRuntimeContext()); + FunctionUtils.openFunction(calc, parameters); + this.calcCollector = new CalcCollector(collector); + } + + @Override + public void close() throws Exception { + super.close(); + FunctionUtils.closeFunction(calc); + } + + @Override + public Collector getFetcherCollector() { + return calcCollector; + } + + private class CalcCollector implements Collector { + + private final Collector delegate; + + private CalcCollector(Collector delegate) { + this.delegate = delegate; + } + + @Override + public void collect(BaseRow record) { + try { + calc.flatMap(record, delegate); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @Override + public void close() { + delegate.close(); + } + } +} diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/generated/GeneratedCollectorWrapper.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/generated/GeneratedCollectorWrapper.java new file mode 100644 index 0000000000000..2e3744cbc7796 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/generated/GeneratedCollectorWrapper.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.generated; + +import org.apache.flink.util.Collector; + +/** + * A wrapper for {@link GeneratedCollector} which wraps a class instead of generated code in it. + * It is only used for easy testing. + */ +public class GeneratedCollectorWrapper> extends GeneratedCollector { + + private static final long serialVersionUID = 3964204655565783705L; + private final Class clazz; + + public GeneratedCollectorWrapper(C collector) { + super(collector.getClass().getSimpleName(), "N/A", new Object[0]); + //noinspection unchecked + this.clazz = (Class) collector.getClass(); + } + + @Override + public C newInstance(ClassLoader classLoader) { + try { + return clazz.newInstance(); + } catch (Exception e) { + throw new RuntimeException("Could not instantiate class " + clazz.getCanonicalName(), e); + } + } + + @Override + public Class compile(ClassLoader classLoader) { + return clazz; + } +} diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/generated/GeneratedFunctionWrapper.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/generated/GeneratedFunctionWrapper.java new file mode 100644 index 0000000000000..934eca281736a --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/generated/GeneratedFunctionWrapper.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.generated; + +import org.apache.flink.api.common.functions.Function; + +/** + * A wrapper for {@link GeneratedFunction} which wraps a class instead of generated code in it. + * It is only used for easy testing. 
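Usage of these wrappers is straightforward; a hedged sketch follows (the fetcher and collector classes named here are hypothetical test doubles, and they need public no-arg constructors because the wrappers instantiate them via Class#newInstance):

    // Illustrative only: plugging hand-written classes into LookupJoinRunner without code generation.
    LookupJoinRunner runner = new LookupJoinRunner(
        new GeneratedFunctionWrapper<>(new TestingFetcherFunction()),   // hypothetical FlatMapFunction<BaseRow, BaseRow>
        new GeneratedCollectorWrapper<>(new TestingFetcherCollector()), // hypothetical TableFunctionCollector<BaseRow>
        true,                                                           // isLeftOuterJoin
        2);                                                             // arity of the lookup table row, used for null padding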
+ */ +public class GeneratedFunctionWrapper extends GeneratedFunction { + + private static final long serialVersionUID = 3964204655565783705L; + private final Class clazz; + + public GeneratedFunctionWrapper(F function) { + super(function.getClass().getSimpleName(), "N/A", new Object[0]); + //noinspection unchecked + this.clazz = (Class) function.getClass(); + } + + @Override + public F newInstance(ClassLoader classLoader) { + try { + return clazz.newInstance(); + } catch (Exception e) { + throw new RuntimeException("Could not instantiate class " + clazz.getCanonicalName(), e); + } + } + + @Override + public Class compile(ClassLoader classLoader) { + return clazz; + } +} diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/generated/GeneratedResultFutureWrapper.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/generated/GeneratedResultFutureWrapper.java new file mode 100644 index 0000000000000..ce7a096112628 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/generated/GeneratedResultFutureWrapper.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.generated; + +import org.apache.flink.streaming.api.functions.async.ResultFuture; + +/** + * A wrapper for {@link GeneratedResultFuture} which wraps a class instead of generated code in it. + * It is only used for easy testing. + */ +public class GeneratedResultFutureWrapper> extends GeneratedResultFuture { + + private static final long serialVersionUID = 3964204655565783705L; + private final Class clazz; + + public GeneratedResultFutureWrapper(T resultFuture) { + super(resultFuture.getClass().getSimpleName(), "N/A", new Object[0]); + //noinspection unchecked + this.clazz = (Class) resultFuture.getClass(); + } + + @Override + public T newInstance(ClassLoader classLoader) { + try { + return clazz.newInstance(); + } catch (Exception e) { + throw new RuntimeException("Could not instantiate class " + clazz.getCanonicalName(), e); + } + } + + @Override + public Class compile(ClassLoader classLoader) { + return clazz; + } +} diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/join/AsyncLookupJoinHarnessTest.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/join/AsyncLookupJoinHarnessTest.java new file mode 100644 index 0000000000000..969346a50ef2f --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/join/AsyncLookupJoinHarnessTest.java @@ -0,0 +1,347 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.join;
+
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.functions.AbstractRichFunction;
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.streaming.api.datastream.AsyncDataStream;
+import org.apache.flink.streaming.api.functions.async.AsyncFunction;
+import org.apache.flink.streaming.api.functions.async.ResultFuture;
+import org.apache.flink.streaming.api.functions.async.RichAsyncFunction;
+import org.apache.flink.streaming.api.operators.async.AsyncWaitOperator;
+import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness;
+import org.apache.flink.table.dataformat.BaseRow;
+import org.apache.flink.table.dataformat.BinaryString;
+import org.apache.flink.table.dataformat.GenericRow;
+import org.apache.flink.table.generated.GeneratedFunctionWrapper;
+import org.apache.flink.table.generated.GeneratedResultFutureWrapper;
+import org.apache.flink.table.runtime.collector.TableFunctionCollector;
+import org.apache.flink.table.runtime.collector.TableFunctionResultFuture;
+import org.apache.flink.table.runtime.join.lookup.AsyncLookupJoinRunner;
+import org.apache.flink.table.runtime.join.lookup.AsyncLookupJoinWithCalcRunner;
+import org.apache.flink.table.runtime.join.lookup.LookupJoinRunner;
+import org.apache.flink.table.runtime.join.lookup.LookupJoinWithCalcRunner;
+import org.apache.flink.table.runtime.util.BaseRowHarnessAssertor;
+import org.apache.flink.table.type.InternalTypes;
+import org.apache.flink.table.typeutils.BaseRowSerializer;
+import org.apache.flink.table.typeutils.BaseRowTypeInfo;
+import org.apache.flink.util.Collector;
+
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.function.Supplier;
+
+import static org.apache.flink.table.dataformat.BinaryString.fromString;
+import static org.apache.flink.table.runtime.util.StreamRecordUtils.record;
+
+/**
+ * Harness tests for {@link AsyncLookupJoinRunner} and {@link AsyncLookupJoinWithCalcRunner}.
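+ *
+ * <p>Each test drives a handful of rows through an {@link AsyncWaitOperator} that wraps the
+ * runner under test and then compares the emitted joined rows. The general flow is roughly:
+ *
+ * <pre>{@code
+ * OneInputStreamOperatorTestHarness<BaseRow, BaseRow> harness = createHarness(joinType, filterOnTable);
+ * harness.open();
+ * synchronized (harness.getCheckpointLock()) {
+ *     harness.processElement(record(1, "a"));
+ * }
+ * synchronized (harness.getCheckpointLock()) {
+ *     harness.close(); // waits until pending async results have been emitted
+ * }
+ * assertor.assertOutputEquals("output wrong.", expectedOutput, harness.getOutput());
+ * }</pre>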
+ */ +public class AsyncLookupJoinHarnessTest { + + private static final int ASYNC_BUFFER_CAPACITY = 100; + private static final int ASYNC_TIMEOUT_MS = 3000; + + private final TypeSerializer inSerializer = new BaseRowSerializer( + new ExecutionConfig(), + InternalTypes.INT, + InternalTypes.STRING); + + private final BaseRowHarnessAssertor assertor = new BaseRowHarnessAssertor(new TypeInformation[]{ + Types.INT, + Types.STRING, + Types.INT, + Types.STRING + }); + + private BaseRowTypeInfo rightRowTypeInfo = new BaseRowTypeInfo(InternalTypes.INT, InternalTypes.STRING); + private TypeInformation fetcherReturnType = rightRowTypeInfo; + + @Test + public void testTemporalInnerAsyncJoin() throws Exception { + OneInputStreamOperatorTestHarness testHarness = createHarness( + JoinType.INNER_JOIN, + FilterOnTable.WITHOUT_FILTER); + + testHarness.open(); + + synchronized (testHarness.getCheckpointLock()) { + testHarness.processElement(record(1, "a")); + testHarness.processElement(record(2, "b")); + testHarness.processElement(record(3, "c")); + testHarness.processElement(record(4, "d")); + testHarness.processElement(record(5, "e")); + } + + // wait until all async collectors in the buffer have been emitted out. + synchronized (testHarness.getCheckpointLock()) { + testHarness.close(); + } + + List expectedOutput = new ArrayList<>(); + expectedOutput.add(record(1, "a", 1, "Julian")); + expectedOutput.add(record(3, "c", 3, "Jark")); + expectedOutput.add(record(3, "c", 3, "Jackson")); + expectedOutput.add(record(4, "d", 4, "Fabian")); + + assertor.assertOutputEquals("output wrong.", expectedOutput, testHarness.getOutput()); + } + + @Test + public void testTemporalInnerAsyncJoinWithFilter() throws Exception { + OneInputStreamOperatorTestHarness testHarness = createHarness( + JoinType.INNER_JOIN, + FilterOnTable.WITH_FILTER); + + testHarness.open(); + + synchronized (testHarness.getCheckpointLock()) { + testHarness.processElement(record(1, "a")); + testHarness.processElement(record(2, "b")); + testHarness.processElement(record(3, "c")); + testHarness.processElement(record(4, "d")); + testHarness.processElement(record(5, "e")); + } + + // wait until all async collectors in the buffer have been emitted out. + synchronized (testHarness.getCheckpointLock()) { + testHarness.close(); + } + + List expectedOutput = new ArrayList<>(); + expectedOutput.add(record(1, "a", 1, "Julian")); + expectedOutput.add(record(3, "c", 3, "Jackson")); + expectedOutput.add(record(4, "d", 4, "Fabian")); + + assertor.assertOutputEquals("output wrong.", expectedOutput, testHarness.getOutput()); + } + + @Test + public void testTemporalLeftAsyncJoin() throws Exception { + OneInputStreamOperatorTestHarness testHarness = createHarness( + JoinType.LEFT_JOIN, + FilterOnTable.WITHOUT_FILTER); + + testHarness.open(); + + synchronized (testHarness.getCheckpointLock()) { + testHarness.processElement(record(1, "a")); + testHarness.processElement(record(2, "b")); + testHarness.processElement(record(3, "c")); + testHarness.processElement(record(4, "d")); + testHarness.processElement(record(5, "e")); + } + + // wait until all async collectors in the buffer have been emitted out. 
+		synchronized (testHarness.getCheckpointLock()) {
+			testHarness.close();
+		}
+
+		List<Object> expectedOutput = new ArrayList<>();
+		expectedOutput.add(record(1, "a", 1, "Julian"));
+		expectedOutput.add(record(2, "b", null, null));
+		expectedOutput.add(record(3, "c", 3, "Jark"));
+		expectedOutput.add(record(3, "c", 3, "Jackson"));
+		expectedOutput.add(record(4, "d", 4, "Fabian"));
+		expectedOutput.add(record(5, "e", null, null));
+
+		assertor.assertOutputEquals("output wrong.", expectedOutput, testHarness.getOutput());
+	}
+
+	@Test
+	public void testTemporalLeftAsyncJoinWithFilter() throws Exception {
+		OneInputStreamOperatorTestHarness<BaseRow, BaseRow> testHarness = createHarness(
+			JoinType.LEFT_JOIN,
+			FilterOnTable.WITH_FILTER);
+
+		testHarness.open();
+
+		synchronized (testHarness.getCheckpointLock()) {
+			testHarness.processElement(record(1, "a"));
+			testHarness.processElement(record(2, "b"));
+			testHarness.processElement(record(3, "c"));
+			testHarness.processElement(record(4, "d"));
+			testHarness.processElement(record(5, "e"));
+		}
+
+		// wait until all async collectors in the buffer have been emitted out.
+		synchronized (testHarness.getCheckpointLock()) {
+			testHarness.close();
+		}
+
+		List<Object> expectedOutput = new ArrayList<>();
+		expectedOutput.add(record(1, "a", 1, "Julian"));
+		expectedOutput.add(record(2, "b", null, null));
+		expectedOutput.add(record(3, "c", 3, "Jackson"));
+		expectedOutput.add(record(4, "d", 4, "Fabian"));
+		expectedOutput.add(record(5, "e", null, null));
+
+		assertor.assertOutputEquals("output wrong.", expectedOutput, testHarness.getOutput());
+	}
+
+	// ---------------------------------------------------------------------------------
+
+	@SuppressWarnings("unchecked")
+	private OneInputStreamOperatorTestHarness<BaseRow, BaseRow> createHarness(
+			JoinType joinType,
+			FilterOnTable filterOnTable) throws Exception {
+		RichAsyncFunction<BaseRow, BaseRow> joinRunner;
+		boolean isLeftJoin = joinType == JoinType.LEFT_JOIN;
+		if (filterOnTable == FilterOnTable.WITHOUT_FILTER) {
+			joinRunner = new AsyncLookupJoinRunner(
+				new GeneratedFunctionWrapper(new TestingFetcherFunction()),
+				new GeneratedResultFutureWrapper<>(new TestingFetcherResultFuture()),
+				fetcherReturnType,
+				rightRowTypeInfo,
+				isLeftJoin,
+				ASYNC_BUFFER_CAPACITY);
+		} else {
+			joinRunner = new AsyncLookupJoinWithCalcRunner(
+				new GeneratedFunctionWrapper(new TestingFetcherFunction()),
+				new GeneratedFunctionWrapper<>(new CalculateOnTemporalTable()),
+				new GeneratedResultFutureWrapper<>(new TestingFetcherResultFuture()),
+				fetcherReturnType,
+				rightRowTypeInfo,
+				isLeftJoin,
+				ASYNC_BUFFER_CAPACITY);
+		}
+
+		AsyncWaitOperator<BaseRow, BaseRow> operator = new AsyncWaitOperator<>(
+			joinRunner,
+			ASYNC_TIMEOUT_MS,
+			ASYNC_BUFFER_CAPACITY,
+			AsyncDataStream.OutputMode.ORDERED);
+
+		return new OneInputStreamOperatorTestHarness<>(
+			operator,
+			inSerializer);
+	}
+
+	/**
+	 * Whether this is an inner join or left join.
+	 */
+	private enum JoinType {
+		INNER_JOIN,
+		LEFT_JOIN
+	}
+
+	/**
+	 * Whether there is a filter on the temporal table.
+	 */
+	private enum FilterOnTable {
+		WITH_FILTER,
+		WITHOUT_FILTER
+	}
+
+	// ---------------------------------------------------------------------------------
+
+
+	/**
+	 * The {@link TestingFetcherFunction} only accepts a single integer lookup key and
+	 * returns zero, one, or more BaseRows.
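+	 *
+	 * <p>For reference, the lookup data served by this fetcher is (mirroring the static map below):
+	 *
+	 * <pre>{@code
+	 * 1 -> [ (1, "Julian") ]
+	 * 3 -> [ (3, "Jark"), (3, "Jackson") ]
+	 * 4 -> [ (4, "Fabian") ]
+	 * }</pre>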
+ */ + public static final class TestingFetcherFunction + extends AbstractRichFunction + implements AsyncFunction { + + private static final long serialVersionUID = 4018474964018227081L; + + private static final Map> data = new HashMap<>(); + + static { + data.put(1, Collections.singletonList( + GenericRow.of(1, fromString("Julian")))); + data.put(3, Arrays.asList( + GenericRow.of(3, fromString("Jark")), + GenericRow.of(3, fromString("Jackson")))); + data.put(4, Collections.singletonList( + GenericRow.of(4, fromString("Fabian")))); + } + + private transient ExecutorService executor; + + @Override + public void open(Configuration parameters) throws Exception { + super.open(parameters); + this.executor = Executors.newSingleThreadExecutor(); + } + + @Override + public void asyncInvoke(BaseRow input, ResultFuture resultFuture) throws Exception { + int id = input.getInt(0); + CompletableFuture + .supplyAsync((Supplier>) () -> data.get(id), executor) + .thenAcceptAsync(resultFuture::complete, executor); + } + + @Override + public void close() throws Exception { + super.close(); + if (null != executor && !executor.isShutdown()) { + executor.shutdown(); + } + } + } + + /** + * The {@link TestingFetcherResultFuture} is a simple implementation of + * {@link TableFunctionCollector} which forwards the collected collection. + */ + public static final class TestingFetcherResultFuture extends TableFunctionResultFuture { + private static final long serialVersionUID = -312754413938303160L; + + @Override + public void complete(Collection result) { + //noinspection unchecked + getResultFuture().complete((Collection) result); + } + } + + /** + * The {@link CalculateOnTemporalTable} is a filter on temporal table which only accepts + * length of name greater than or equal to 6. + */ + public static final class CalculateOnTemporalTable implements FlatMapFunction { + + private static final long serialVersionUID = -1860345072157431136L; + + @Override + public void flatMap(BaseRow value, Collector out) throws Exception { + BinaryString name = value.getString(1); + if (name.getSizeInBytes() >= 6) { + out.collect(value); + } + } + } +} diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/join/LookupJoinHarnessTest.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/join/LookupJoinHarnessTest.java new file mode 100644 index 0000000000000..5db79790f99a3 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/join/LookupJoinHarnessTest.java @@ -0,0 +1,280 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.join; + +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.streaming.api.functions.ProcessFunction; +import org.apache.flink.streaming.api.operators.ProcessOperator; +import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness; +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.dataformat.BinaryString; +import org.apache.flink.table.dataformat.GenericRow; +import org.apache.flink.table.dataformat.JoinedRow; +import org.apache.flink.table.generated.GeneratedCollectorWrapper; +import org.apache.flink.table.generated.GeneratedFunctionWrapper; +import org.apache.flink.table.runtime.collector.TableFunctionCollector; +import org.apache.flink.table.runtime.join.lookup.LookupJoinRunner; +import org.apache.flink.table.runtime.join.lookup.LookupJoinWithCalcRunner; +import org.apache.flink.table.runtime.util.BaseRowHarnessAssertor; +import org.apache.flink.table.type.InternalTypes; +import org.apache.flink.table.typeutils.BaseRowSerializer; +import org.apache.flink.util.Collector; + +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.flink.table.dataformat.BinaryString.fromString; +import static org.apache.flink.table.runtime.util.StreamRecordUtils.record; + +/** + * Harness tests for {@link LookupJoinRunner} and {@link LookupJoinWithCalcRunner}. + */ +public class LookupJoinHarnessTest { + + private final TypeSerializer inSerializer = new BaseRowSerializer( + new ExecutionConfig(), + InternalTypes.INT, + InternalTypes.STRING); + + private final BaseRowHarnessAssertor assertor = new BaseRowHarnessAssertor(new TypeInformation[]{ + Types.INT, + Types.STRING, + Types.INT, + Types.STRING + }); + + @Test + public void testTemporalInnerJoin() throws Exception { + OneInputStreamOperatorTestHarness testHarness = createHarness( + JoinType.INNER_JOIN, + FilterOnTable.WITHOUT_FILTER); + + testHarness.open(); + + testHarness.processElement(record(1, "a")); + testHarness.processElement(record(2, "b")); + testHarness.processElement(record(3, "c")); + testHarness.processElement(record(4, "d")); + testHarness.processElement(record(5, "e")); + + List expectedOutput = new ArrayList<>(); + expectedOutput.add(record(1, "a", 1, "Julian")); + expectedOutput.add(record(3, "c", 3, "Jark")); + expectedOutput.add(record(3, "c", 3, "Jackson")); + expectedOutput.add(record(4, "d", 4, "Fabian")); + + assertor.assertOutputEquals("output wrong.", expectedOutput, testHarness.getOutput()); + testHarness.close(); + } + + @Test + public void testTemporalInnerJoinWithFilter() throws Exception { + OneInputStreamOperatorTestHarness testHarness = createHarness( + JoinType.INNER_JOIN, + FilterOnTable.WITH_FILTER); + + testHarness.open(); + + testHarness.processElement(record(1, "a")); + testHarness.processElement(record(2, "b")); + testHarness.processElement(record(3, "c")); + testHarness.processElement(record(4, "d")); + testHarness.processElement(record(5, "e")); + + List expectedOutput = new ArrayList<>(); + expectedOutput.add(record(1, "a", 1, "Julian")); + expectedOutput.add(record(3, "c", 3, "Jackson")); + 
expectedOutput.add(record(4, "d", 4, "Fabian")); + + assertor.assertOutputEquals("output wrong.", expectedOutput, testHarness.getOutput()); + testHarness.close(); + } + + @Test + public void testTemporalLeftJoin() throws Exception { + OneInputStreamOperatorTestHarness testHarness = createHarness( + JoinType.LEFT_JOIN, + FilterOnTable.WITHOUT_FILTER); + + testHarness.open(); + + testHarness.processElement(record(1, "a")); + testHarness.processElement(record(2, "b")); + testHarness.processElement(record(3, "c")); + testHarness.processElement(record(4, "d")); + testHarness.processElement(record(5, "e")); + + List expectedOutput = new ArrayList<>(); + expectedOutput.add(record(1, "a", 1, "Julian")); + expectedOutput.add(record(2, "b", null, null)); + expectedOutput.add(record(3, "c", 3, "Jark")); + expectedOutput.add(record(3, "c", 3, "Jackson")); + expectedOutput.add(record(4, "d", 4, "Fabian")); + expectedOutput.add(record(5, "e", null, null)); + + assertor.assertOutputEquals("output wrong.", expectedOutput, testHarness.getOutput()); + testHarness.close(); + } + + @Test + public void testTemporalLeftJoinWithFilter() throws Exception { + OneInputStreamOperatorTestHarness testHarness = createHarness( + JoinType.LEFT_JOIN, + FilterOnTable.WITH_FILTER); + + testHarness.open(); + + testHarness.processElement(record(1, "a")); + testHarness.processElement(record(2, "b")); + testHarness.processElement(record(3, "c")); + testHarness.processElement(record(4, "d")); + testHarness.processElement(record(5, "e")); + + List expectedOutput = new ArrayList<>(); + expectedOutput.add(record(1, "a", 1, "Julian")); + expectedOutput.add(record(2, "b", null, null)); + expectedOutput.add(record(3, "c", 3, "Jackson")); + expectedOutput.add(record(4, "d", 4, "Fabian")); + expectedOutput.add(record(5, "e", null, null)); + + assertor.assertOutputEquals("output wrong.", expectedOutput, testHarness.getOutput()); + testHarness.close(); + } + + // --------------------------------------------------------------------------------- + + @SuppressWarnings("unchecked") + private OneInputStreamOperatorTestHarness createHarness( + JoinType joinType, + FilterOnTable filterOnTable) throws Exception { + boolean isLeftJoin = joinType == JoinType.LEFT_JOIN; + ProcessFunction joinRunner; + if (filterOnTable == FilterOnTable.WITHOUT_FILTER) { + joinRunner = new LookupJoinRunner( + new GeneratedFunctionWrapper<>(new TestingFetcherFunction()), + new GeneratedCollectorWrapper<>(new TestingFetcherCollector()), + isLeftJoin, + 2); + } else { + joinRunner = new LookupJoinWithCalcRunner( + new GeneratedFunctionWrapper<>(new TestingFetcherFunction()), + new GeneratedFunctionWrapper<>(new CalculateOnTemporalTable()), + new GeneratedCollectorWrapper<>(new TestingFetcherCollector()), + isLeftJoin, + 2); + } + + ProcessOperator operator = new ProcessOperator<>(joinRunner); + return new OneInputStreamOperatorTestHarness<>( + operator, + inSerializer); + } + + /** + * Whether this is a inner join or left join. + */ + private enum JoinType { + INNER_JOIN, + LEFT_JOIN + } + + /** + * Whether there is a filter on temporal table. + */ + private enum FilterOnTable { + WITH_FILTER, + WITHOUT_FILTER + } + + // --------------------------------------------------------------------------------- + + /** + * The {@link TestingFetcherFunction} only accepts a single integer lookup key and + * returns zero or one or more BaseRows. 
+	 */
+	public static final class TestingFetcherFunction implements FlatMapFunction<BaseRow, BaseRow> {
+
+		private static final long serialVersionUID = 4018474964018227081L;
+
+		private static final Map<Integer, List<GenericRow>> data = new HashMap<>();
+
+		static {
+			data.put(1, Collections.singletonList(
+				GenericRow.of(1, fromString("Julian"))));
+			data.put(3, Arrays.asList(
+				GenericRow.of(3, fromString("Jark")),
+				GenericRow.of(3, fromString("Jackson"))));
+			data.put(4, Collections.singletonList(
+				GenericRow.of(4, fromString("Fabian"))));
+		}
+
+		@Override
+		public void flatMap(BaseRow value, Collector<BaseRow> out) throws Exception {
+			int id = value.getInt(0);
+			List<GenericRow> rows = data.get(id);
+			if (rows != null) {
+				for (GenericRow row : rows) {
+					out.collect(row);
+				}
+			}
+		}
+	}
+
+	/**
+	 * The {@link TestingFetcherCollector} is a simple implementation of
+	 * {@link TableFunctionCollector} which combines left and right into a JoinedRow.
+	 */
+	public static final class TestingFetcherCollector extends TableFunctionCollector {
+		private static final long serialVersionUID = -312754413938303160L;
+
+		@Override
+		public void collect(Object record) {
+			BaseRow left = (BaseRow) getInput();
+			BaseRow right = (BaseRow) record;
+			outputResult(new JoinedRow(left, right));
+		}
+	}
+
+	/**
+	 * The {@link CalculateOnTemporalTable} is a filter on the temporal table which only accepts
+	 * rows whose name is at least 6 bytes long.
+	 */
+	public static final class CalculateOnTemporalTable implements FlatMapFunction<BaseRow, BaseRow> {
+
+		private static final long serialVersionUID = -1860345072157431136L;
+
+		@Override
+		public void flatMap(BaseRow value, Collector<BaseRow> out) throws Exception {
+			BinaryString name = value.getString(1);
+			if (name.getSizeInBytes() >= 6) {
+				out.collect(value);
+			}
+		}
+	}
+}
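
As an illustration of the pattern above, a hypothetical additional case for LookupJoinHarnessTest could look like the sketch below. It only uses helpers already defined in that class (createHarness, record, assertor) and the lookup data of TestingFetcherFunction, where keys 2 and 5 have no match and are therefore dropped by an inner join:

@Test
public void testTemporalInnerJoinWithNoMatchingKeys() throws Exception {
	OneInputStreamOperatorTestHarness<BaseRow, BaseRow> testHarness = createHarness(
		JoinType.INNER_JOIN,
		FilterOnTable.WITHOUT_FILTER);

	testHarness.open();

	// keys 2 and 5 have no entry in TestingFetcherFunction#data, so an inner join emits nothing
	testHarness.processElement(record(2, "b"));
	testHarness.processElement(record(5, "e"));

	assertor.assertOutputEquals("output wrong.", new ArrayList<>(), testHarness.getOutput());
	testHarness.close();
}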