@@ -23,13 +23,18 @@
import org.apache.flink.api.dag.Transformation;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.configuration.ReadableConfig;
import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
import org.apache.flink.streaming.api.datastream.AsyncDataStream;
import org.apache.flink.streaming.api.functions.ProcessFunction;
import org.apache.flink.streaming.api.functions.async.AsyncFunction;
import org.apache.flink.streaming.api.operators.KeyedProcessOperator;
import org.apache.flink.streaming.api.operators.ProcessOperator;
import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
import org.apache.flink.streaming.api.operators.async.AsyncWaitOperatorFactory;
import org.apache.flink.streaming.api.transformations.OneInputTransformation;
import org.apache.flink.streaming.api.transformations.PartitionTransformation;
import org.apache.flink.streaming.runtime.partitioner.KeyGroupStreamPartitioner;
import org.apache.flink.table.api.TableException;
import org.apache.flink.table.api.config.ExecutionConfigOptions;
import org.apache.flink.table.catalog.DataTypeFactory;
@@ -57,23 +62,28 @@
import org.apache.flink.table.planner.plan.nodes.exec.utils.ExecNodeUtil;
import org.apache.flink.table.planner.plan.schema.LegacyTableSourceTable;
import org.apache.flink.table.planner.plan.schema.TableSourceTable;
import org.apache.flink.table.planner.plan.utils.KeySelectorUtil;
import org.apache.flink.table.planner.plan.utils.LookupJoinUtil;
import org.apache.flink.table.planner.utils.JavaScalaConversionUtil;
import org.apache.flink.table.planner.utils.ShortcutUtils;
- import org.apache.flink.table.runtime.collector.TableFunctionCollector;
+ import org.apache.flink.table.runtime.collector.ListenableCollector;
import org.apache.flink.table.runtime.collector.TableFunctionResultFuture;
import org.apache.flink.table.runtime.generated.GeneratedCollector;
import org.apache.flink.table.runtime.generated.GeneratedFunction;
import org.apache.flink.table.runtime.generated.GeneratedResultFuture;
import org.apache.flink.table.runtime.keyselector.EmptyRowDataKeySelector;
import org.apache.flink.table.runtime.keyselector.RowDataKeySelector;
import org.apache.flink.table.runtime.operators.join.FlinkJoinType;
import org.apache.flink.table.runtime.operators.join.lookup.AsyncLookupJoinRunner;
import org.apache.flink.table.runtime.operators.join.lookup.AsyncLookupJoinWithCalcRunner;
import org.apache.flink.table.runtime.operators.join.lookup.KeyedLookupJoinWrapper;
import org.apache.flink.table.runtime.operators.join.lookup.LookupJoinRunner;
import org.apache.flink.table.runtime.operators.join.lookup.LookupJoinWithCalcRunner;
import org.apache.flink.table.runtime.types.PlannerTypeUtils;
import org.apache.flink.table.runtime.types.TypeInfoDataTypeConverter;
import org.apache.flink.table.runtime.typeutils.InternalSerializers;
import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
import org.apache.flink.table.runtime.util.StateConfigUtil;
import org.apache.flink.table.sources.LookupableTableSource;
import org.apache.flink.table.sources.TableSource;
import org.apache.flink.table.types.logical.LogicalType;
Expand All @@ -96,6 +106,7 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;

import static org.apache.flink.table.planner.calcite.FlinkTypeFactory.toLogicalType;
import static org.apache.flink.table.planner.utils.ShortcutUtils.unwrapTypeFactory;
@@ -143,6 +154,8 @@ public abstract class CommonExecLookupJoin extends ExecNodeBase<RowData>

public static final String LOOKUP_JOIN_TRANSFORMATION = "lookup-join";

public static final String LOOKUP_JOIN_MATERIALIZE_TRANSFORMATION = "lookup-join-materialize";

public static final String FIELD_NAME_JOIN_TYPE = "joinType";
public static final String FIELD_NAME_JOIN_CONDITION = "joinCondition";
public static final String FIELD_NAME_TEMPORAL_TABLE = "temporalTable";
@@ -329,17 +342,71 @@ private Transformation<RowData> createSyncLookupJoinWithState(
isLeftOuterJoin,
isObjectReuseEnabled);

- // TODO then wrapper it into a keyed lookup function with state FLINK-28568
- throw new UnsupportedOperationException("to be supported");
- }
-
- private LogicalType getLookupKeyLogicalType(
- LookupJoinUtil.LookupKey lookupKey, RowType inputRowType) {
- if (lookupKey instanceof LookupJoinUtil.FieldRefLookupKey) {
- return inputRowType.getTypeAt(((LookupJoinUtil.FieldRefLookupKey) lookupKey).index);
+ RowType rightRowType =
+ getRightOutputRowType(
+ getProjectionOutputRelDataType(relBuilder), tableSourceRowType);
+
+ KeyedLookupJoinWrapper keyedLookupJoinWrapper =
+ new KeyedLookupJoinWrapper(
+ (LookupJoinRunner) processFunction,
+ StateConfigUtil.createTtlConfig(
+ config.get(ExecutionConfigOptions.IDLE_STATE_RETENTION).toMillis()),
+ InternalSerializers.create(rightRowType),
+ lookupKeyContainsPrimaryKey);
+
+ KeyedProcessOperator<RowData, RowData, RowData> operator =
+ new KeyedProcessOperator<>(keyedLookupJoinWrapper);
+
+ List<Integer> refKeys =
+ allLookupKeys.values().stream()
+ .filter(key -> key instanceof LookupJoinUtil.FieldRefLookupKey)
+ .map(key -> ((LookupJoinUtil.FieldRefLookupKey) key).index)
+ .collect(Collectors.toList());
+ RowDataKeySelector keySelector;
+
+ // use single parallelism for empty key shuffle
+ boolean singleParallelism = refKeys.isEmpty();
+ if (singleParallelism) {
+ // all lookup keys are constants, then use an empty key selector
+ keySelector = EmptyRowDataKeySelector.INSTANCE;
} else {
- return ((LookupJoinUtil.ConstantLookupKey) lookupKey).sourceType;
+ // make it a deterministic asc order
+ Collections.sort(refKeys);
+ keySelector =
+ KeySelectorUtil.getRowDataSelector(
+ classLoader,
+ refKeys.stream().mapToInt(Integer::intValue).toArray(),
+ InternalTypeInfo.of(inputRowType));
}
+ final KeyGroupStreamPartitioner<RowData, RowData> partitioner =
+ new KeyGroupStreamPartitioner<>(
+ keySelector, KeyGroupRangeAssignment.DEFAULT_LOWER_BOUND_MAX_PARALLELISM);
+ Transformation<RowData> partitionedTransform =
+ new PartitionTransformation<>(inputTransformation, partitioner);
+ if (singleParallelism) {
+ setSingletonParallelism(partitionedTransform);
+ } else {
+ partitionedTransform.setParallelism(inputTransformation.getParallelism());
+ }
+
+ OneInputTransformation<RowData, RowData> transform =
+ ExecNodeUtil.createOneInputTransformation(
+ partitionedTransform,
+ createTransformationMeta(LOOKUP_JOIN_MATERIALIZE_TRANSFORMATION, config),
+ operator,
+ InternalTypeInfo.of(resultRowType),
+ partitionedTransform.getParallelism());
+ transform.setStateKeySelector(keySelector);
+ transform.setStateKeyType(keySelector.getProducedType());
+ if (singleParallelism) {
+ setSingletonParallelism(transform);
+ }
+ return transform;
}

+ private void setSingletonParallelism(Transformation transformation) {
+ transformation.setParallelism(1);
+ transformation.setMaxParallelism(1);
+ }

protected void validateLookupKeyType(
@@ -413,16 +480,9 @@ private StreamOperatorFactory<RowData> createAsyncLookupJoin(
asyncLookupFunction,
StringUtils.join(temporalTable.getQualifiedName(), "."));

- Optional<RelDataType> temporalTableOutputType =
- projectionOnTemporalTable != null
- ? Optional.of(
- RexUtil.createStructType(
- unwrapTypeFactory(relBuilder), projectionOnTemporalTable))
- : Optional.empty();
+ RelDataType projectionOutputRelDataType = getProjectionOutputRelDataType(relBuilder);
RowType rightRowType =
- projectionOnTemporalTable != null
- ? (RowType) toLogicalType(temporalTableOutputType.get())
- : tableSourceRowType;
+ getRightOutputRowType(projectionOutputRelDataType, tableSourceRowType);
// a projection or filter after table source scan
GeneratedResultFuture<TableFunctionResultFuture<RowData>> generatedResultFuture =
LookupJoinCodeGenerator.generateTableAsyncCollector(
Expand All @@ -444,7 +504,7 @@ private StreamOperatorFactory<RowData> createAsyncLookupJoin(
classLoader,
JavaScalaConversionUtil.toScala(projectionOnTemporalTable),
filterOnTemporalTable,
- temporalTableOutputType.get(),
+ projectionOutputRelDataType,
tableSourceRowType);
asyncFunc =
new AsyncLookupJoinWithCalcRunner(
@@ -508,6 +568,19 @@ private StreamOperatorFactory<RowData> createSyncLookupJoin(
isObjectReuseEnabled)));
}

private RelDataType getProjectionOutputRelDataType(RelBuilder relBuilder) {
return projectionOnTemporalTable != null
? RexUtil.createStructType(unwrapTypeFactory(relBuilder), projectionOnTemporalTable)
: null;
}

private RowType getRightOutputRowType(
RelDataType projectionOutputRelDataType, RowType tableSourceRowType) {
return projectionOutputRelDataType != null
? (RowType) toLogicalType(projectionOutputRelDataType)
: tableSourceRowType;
}

private ProcessFunction<RowData, RowData> createSyncLookupJoinFunction(
RelOptTable temporalTable,
ExecNodeConfig config,
@@ -540,17 +613,10 @@ private ProcessFunction<RowData, RowData> createSyncLookupJoinFunction(
StringUtils.join(temporalTable.getQualifiedName(), "."),
isObjectReuseEnabled);

- Optional<RelDataType> temporalTableOutputType =
- projectionOnTemporalTable != null
- ? Optional.of(
- RexUtil.createStructType(
- unwrapTypeFactory(relBuilder), projectionOnTemporalTable))
- : Optional.empty();
+ RelDataType projectionOutputRelDataType = getProjectionOutputRelDataType(relBuilder);
RowType rightRowType =
- projectionOnTemporalTable != null
- ? (RowType) toLogicalType(temporalTableOutputType.get())
- : tableSourceRowType;
- GeneratedCollector<TableFunctionCollector<RowData>> generatedCollector =
+ getRightOutputRowType(projectionOutputRelDataType, tableSourceRowType);
+ GeneratedCollector<ListenableCollector<RowData>> generatedCollector =
LookupJoinCodeGenerator.generateCollector(
new CodeGeneratorContext(config, classLoader),
inputRowType,
Expand All @@ -568,7 +634,7 @@ private ProcessFunction<RowData, RowData> createSyncLookupJoinFunction(
classLoader,
JavaScalaConversionUtil.toScala(projectionOnTemporalTable),
filterOnTemporalTable,
- temporalTableOutputType.get(),
+ projectionOutputRelDataType,
tableSourceRowType);

processFunc =
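Note on the shuffle introduced above: the rewritten createSyncLookupJoinWithState keys the input stream by the FieldRef lookup keys (sorted ascending for determinism) and inserts a KeyGroupStreamPartitioner, so the state that KeyedLookupJoinWrapper materializes for a given key always lives on the same subtask. Below is a minimal standalone sketch of how such a key is routed, using Flink's KeyGroupRangeAssignment utility; the key value and parallelism are illustrative only, and KeyRoutingSketch is a hypothetical class, not part of this PR:

import org.apache.flink.runtime.state.KeyGroupRangeAssignment;

public class KeyRoutingSketch {
    public static void main(String[] args) {
        // same lower bound the PR passes to KeyGroupStreamPartitioner (128)
        int maxParallelism = KeyGroupRangeAssignment.DEFAULT_LOWER_BOUND_MAX_PARALLELISM;
        int parallelism = 4; // illustrative operator parallelism
        Object key = 42; // stands in for the RowData key extracted by the key selector
        int keyGroup = KeyGroupRangeAssignment.assignToKeyGroup(key, maxParallelism);
        int subtask =
                KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(
                        maxParallelism, parallelism, keyGroup);
        System.out.println("key " + key + " -> key group " + keyGroup + " -> subtask " + subtask);
    }
}

When every lookup key is a constant there is nothing to hash, which is why the code above falls back to EmptyRowDataKeySelector and pins both the shuffle and the operator to parallelism 1.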
@@ -181,11 +181,15 @@ public static UserDefinedFunction getLookupFunction(
LookupTableSource.LookupRuntimeProvider provider =
tableSource.getLookupRuntimeProvider(providerContext);

- if (requireSyncLookup && !(provider instanceof TableFunctionProvider)) {
+ // TODO this method will be refactored in FLINK-28848
+ if (requireSyncLookup
+ && !(provider instanceof TableFunctionProvider)
+ && !(provider instanceof LookupFunctionProvider)) {
throw new TableException(
String.format(
- "Require a synchronous TableFunction due to planner's requirement but no TableFunctionProvider "
- + "found in TableSourceTable: %s, please check the code to ensure a proper TableFunctionProvider is specified.",
+ "Require a synchronous lookup function due to planner's requirement but no "
+ + "available functions in TableSourceTable: %s, please check the code to ensure "
+ + "a proper LookupFunctionProvider or TableFunctionProvider is specified.",
temporalTable.getQualifiedName()));
}
if (provider instanceof LookupFunctionProvider) {
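The relaxed check above matters because the unified LookupFunctionProvider (introduced by FLIP-221) can also serve synchronous lookups: the LookupFunction it creates is a TableFunction underneath. Here is a hedged connector-side sketch of satisfying the planner's sync requirement with the new provider; SyncLookupProviderSketch and its empty result are illustrative, not taken from this PR:

import java.io.IOException;
import java.util.Collection;
import java.util.Collections;

import org.apache.flink.table.connector.source.lookup.LookupFunctionProvider;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.functions.LookupFunction;

public class SyncLookupProviderSketch {
    // assumption: this would be returned from LookupTableSource#getLookupRuntimeProvider
    public static LookupFunctionProvider provider() {
        return LookupFunctionProvider.of(
                new LookupFunction() {
                    @Override
                    public Collection<RowData> lookup(RowData keyRow) throws IOException {
                        // query the external system with the key row;
                        // an empty collection means no match
                        return Collections.emptyList();
                    }
                });
    }
}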
@@ -37,7 +37,8 @@ import org.apache.flink.table.planner.functions.inference.LookupCallContext
import org.apache.flink.table.planner.plan.utils.LookupJoinUtil.{ConstantLookupKey, FieldRefLookupKey, LookupKey}
import org.apache.flink.table.planner.plan.utils.RexLiteralUtil
import org.apache.flink.table.planner.utils.JavaScalaConversionUtil.toScala
- import org.apache.flink.table.runtime.collector.{TableFunctionCollector, TableFunctionResultFuture}
+ import org.apache.flink.table.runtime.collector.{ListenableCollector, TableFunctionResultFuture}
+ import org.apache.flink.table.runtime.collector.ListenableCollector.CollectListener
import org.apache.flink.table.runtime.generated.{GeneratedCollector, GeneratedFunction, GeneratedResultFuture}
import org.apache.flink.table.types.DataType
import org.apache.flink.table.types.extraction.ExtractionUtils.extractSimpleGeneric
@@ -299,7 +300,7 @@ object LookupJoinCodeGenerator {
resultRowType: RowType,
condition: Option[RexNode],
pojoFieldMapping: Option[Array[Int]],
- retainHeader: Boolean = true): GeneratedCollector[TableFunctionCollector[RowData]] = {
+ retainHeader: Boolean = true): GeneratedCollector[ListenableCollector[RowData]] = {

val inputTerm = DEFAULT_INPUT1_TERM
val rightInputTerm = DEFAULT_INPUT2_TERM
@@ -366,15 +367,15 @@
collectedType: RowType,
inputTerm: String = DEFAULT_INPUT1_TERM,
collectedTerm: String = DEFAULT_INPUT2_TERM)
- : GeneratedCollector[TableFunctionCollector[RowData]] = {
+ : GeneratedCollector[ListenableCollector[RowData]] = {

val funcName = newName(name)
val input1TypeClass = boxedTypeTermForType(inputType)
val input2TypeClass = boxedTypeTermForType(collectedType)

val funcCode =
s"""
- public class $funcName extends ${classOf[TableFunctionCollector[_]].getCanonicalName} {
+ public class $funcName extends ${classOf[ListenableCollector[_]].getCanonicalName} {

${ctx.reuseMemberCode()}

@@ -391,6 +392,16 @@
public void collect(Object record) throws Exception {
$input1TypeClass $inputTerm = ($input1TypeClass) getInput();
$input2TypeClass $collectedTerm = ($input2TypeClass) record;

// callback only when collectListener exists, equivalent to:
// getCollectListener().ifPresent(
// listener -> ((CollectListener) listener).onCollect(record));
// TODO we should update code splitter's grammar file to accept lambda expressions.

if (getCollectListener().isPresent()) {
((${classOf[CollectListener[_]].getCanonicalName}) getCollectListener().get()).onCollect(record);
}

${ctx.reuseLocalVariableCode()}
${ctx.reuseInputUnboxingCode()}
${ctx.reusePerRecordCode()}
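For readers unfamiliar with the collector change: the code generator now emits a subclass of ListenableCollector instead of TableFunctionCollector, and the generated collect() first notifies an optional CollectListener. That hook is what lets KeyedLookupJoinWrapper observe every looked-up row and materialize it into keyed state. A stripped-down sketch of the pattern (not Flink's actual class; the names merely mirror the generated code):

import java.util.Optional;

import org.apache.flink.util.Collector;

abstract class ListenableCollectorSketch<T> implements Collector<T> {
    interface CollectListener<T> {
        // invoked before the record is forwarded downstream
        void onCollect(T record);
    }

    private CollectListener<T> collectListener; // may remain null

    void setCollectListener(CollectListener<T> listener) {
        this.collectListener = listener;
    }

    Optional<CollectListener<T>> getCollectListener() {
        return Optional.ofNullable(collectListener);
    }

    @Override
    public void collect(T record) {
        // the generated code spells this out as an if/isPresent() block because
        // the code splitter's grammar cannot parse lambdas yet (see TODO above)
        getCollectListener().ifPresent(listener -> listener.onCollect(record));
        forward(record);
    }

    protected abstract void forward(T record);
}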
@@ -23,6 +23,7 @@
import org.apache.flink.table.api.TableEnvironment;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.table.api.ValidationException;
import org.apache.flink.table.api.config.OptimizerConfigOptions;
import org.apache.flink.table.planner.runtime.utils.InMemoryLookupableTableSource;
import org.apache.flink.table.planner.utils.StreamTableTestUtil;
import org.apache.flink.table.planner.utils.TableTestBase;
@@ -68,6 +69,16 @@ public void setup() {
+ ") with (\n"
+ " 'connector' = 'values',\n"
+ " 'bounded' = 'false')";
String sinkTable1 =
"CREATE TABLE Sink1 (\n"
+ " a int,\n"
+ " name varchar,"
+ " age int"
+ ") with (\n"
+ " 'connector' = 'values',\n"
+ " 'sink-insert-only' = 'false'\n"
+ ")";
tEnv.executeSql(sinkTable1);
tEnv.executeSql(srcTableA);
tEnv.executeSql(srcTableB);
}
@@ -156,4 +167,19 @@ public void testLegacyTableSourceException() {
ValidationException.class,
"TemporalTableSourceSpec can not be serialized."));
}

@Test
public void testAggAndLeftJoinWithTryResolveMode() {
tEnv.getConfig()
.set(
OptimizerConfigOptions.TABLE_OPTIMIZER_NONDETERMINISTIC_UPDATE_STRATEGY,
OptimizerConfigOptions.NonDeterministicUpdateStrategy.TRY_RESOLVE);

util.verifyJsonPlan(
"INSERT INTO Sink1 "
+ "SELECT T.a, D.name, D.age "
+ "FROM (SELECT max(a) a, count(c) c, PROCTIME() proctime FROM MyTable GROUP BY b) T "
+ "LEFT JOIN LookupTable "
+ "FOR SYSTEM_TIME AS OF T.proctime AS D ON T.a = D.id");
}
}