From 5cbf9a670c151b55631138955e97f1ee0d7ff383 Mon Sep 17 00:00:00 2001 From: Xpray Date: Fri, 25 May 2018 23:58:05 +0800 Subject: [PATCH] [FLINK-9344][TableAPI & SQL] Support INTERSECT and INTERSECT ALL for streaming --- .../calcite/RelTimeIndicatorConverter.scala | 79 +++--- .../flink/table/plan/logical/operators.scala | 4 - .../datastream/DataStreamIntersect.scala | 120 +++++++++ .../table/plan/rules/FlinkRuleSets.scala | 1 + .../datastream/DataStreamIntersectRule.scala | 67 +++++ .../runtime/setop/NonWindowIntersect.scala | 242 ++++++++++++++++++ .../api/stream/table/SetOperatorsTest.scala | 81 ++++++ .../SetOperatorsStringExpressionTest.scala | 22 ++ .../SetOperatorsValidationTest.scala | 45 +++- .../UnsupportedOpsValidationTest.scala | 18 -- .../batch/sql/SetOperatorsITCase.scala | 2 - .../harness/NonWindowHarnessTest.scala | 67 +++++ .../stream/sql/SetOperatorsITCase.scala | 127 ++++++++- .../stream/table/SetOperatorsITCase.scala | 86 +++++++ 14 files changed, 897 insertions(+), 64 deletions(-) create mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamIntersect.scala create mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamIntersectRule.scala create mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/setop/NonWindowIntersect.scala diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/RelTimeIndicatorConverter.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/RelTimeIndicatorConverter.scala index 4f3fbaa8edeb3..a01f34357d832 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/RelTimeIndicatorConverter.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/RelTimeIndicatorConverter.scala @@ -47,45 +47,12 @@ class RelTimeIndicatorConverter(rexBuilder: RexBuilder) extends RelShuttle { .asInstanceOf[FlinkTypeFactory] .createTypeFromTypeInfo(SqlTimeTypeInfo.TIMESTAMP, isNullable = false) - override def visit(intersect: LogicalIntersect): RelNode = - throw new TableException("Logical intersect in a stream environment is not supported yet.") + override def visit(intersect: LogicalIntersect): RelNode = { + visitSetOp(intersect) + } override def visit(union: LogicalUnion): RelNode = { - // visit children and update inputs - val inputs = union.getInputs.map(_.accept(this)) - - // make sure that time indicator types match - val inputTypes = inputs.map(_.getRowType) - - val head = inputTypes.head.getFieldList.map(_.getType) - - val isValid = inputTypes.forall { t => - val fieldTypes = t.getFieldList.map(_.getType) - - fieldTypes.zip(head).forall { case (l, r) => - // check if time indicators match - if (isTimeIndicatorType(l) && isTimeIndicatorType(r)) { - val leftTime = l.asInstanceOf[TimeIndicatorRelDataType].isEventTime - val rightTime = r.asInstanceOf[TimeIndicatorRelDataType].isEventTime - leftTime == rightTime - } - // one side is not an indicator - else if (isTimeIndicatorType(l) || isTimeIndicatorType(r)) { - false - } - // uninteresting types - else { - true - } - } - } - - if (!isValid) { - throw new ValidationException( - "Union fields with time attributes have different types.") - } - - LogicalUnion.create(inputs, union.all) + visitSetOp(union) } override def visit(aggregate: LogicalAggregate): RelNode = convertAggregate(aggregate) @@ -207,6 +174,44 @@ class RelTimeIndicatorConverter(rexBuilder: 
RexBuilder) extends RelShuttle { correlate.getJoinType) } + private def visitSetOp(setOp: SetOp): RelNode = { + // visit children and update inputs + val inputs = setOp.getInputs.map(_.accept(this)) + + // make sure that time indicator types match + val inputTypes = inputs.map(_.getRowType) + + val head = inputTypes.head.getFieldList.map(_.getType) + + val isValid = inputTypes.forall { t => + val fieldTypes = t.getFieldList.map(_.getType) + + fieldTypes.zip(head).forall { case (l, r) => + // check if time indicators match + if (isTimeIndicatorType(l) && isTimeIndicatorType(r)) { + val leftTime = l.asInstanceOf[TimeIndicatorRelDataType].isEventTime + val rightTime = r.asInstanceOf[TimeIndicatorRelDataType].isEventTime + leftTime == rightTime + } + // one side is not an indicator + else if (isTimeIndicatorType(l) || isTimeIndicatorType(r)) { + false + } + // uninteresting types + else { + true + } + } + } + + if (!isValid) { + throw new ValidationException( + "SetOp fields with time attributes have different types.") + } + + setOp.copy(setOp.getTraitSet, inputs, setOp.all) + } + private def convertAggregate(aggregate: Aggregate): LogicalAggregate = { // visit children and update inputs val input = aggregate.getInput.accept(this) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala index a2bd1e45124fe..ed8d6e5a8f9a6 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala @@ -341,10 +341,6 @@ case class Intersect(left: LogicalNode, right: LogicalNode, all: Boolean) extend } override def validate(tableEnv: TableEnvironment): LogicalNode = { - if (tableEnv.isInstanceOf[StreamTableEnvironment]) { - failValidation(s"Intersect on stream tables is currently not supported.") - } - val resolvedIntersect = super.validate(tableEnv).asInstanceOf[Intersect] if (left.output.length != right.output.length) { failValidation(s"Intersect two tables of different column sizes:" + diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamIntersect.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamIntersect.scala new file mode 100644 index 0000000000000..7dba75a03d8ca --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamIntersect.scala @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.flink.table.plan.nodes.datastream
+
+import org.apache.calcite.plan.{RelOptCluster, RelOptCost, RelOptPlanner, RelTraitSet}
+import org.apache.calcite.rel.{BiRel, RelNode, RelWriter}
+import org.apache.calcite.rel.`type`.RelDataType
+import org.apache.calcite.rel.metadata.RelMetadataQuery
+import org.apache.flink.api.java.typeutils.RowTypeInfo
+import org.apache.flink.streaming.api.datastream.DataStream
+import org.apache.flink.table.api.{StreamQueryConfig, StreamTableEnvironment}
+import org.apache.flink.table.plan.schema.RowSchema
+import org.apache.flink.table.runtime.CRowKeySelector
+import org.apache.flink.table.runtime.setop.NonWindowIntersect
+import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo}
+
+import scala.collection.JavaConverters._
+
+/**
+ * RelNode for a non-window stream intersect.
+ */
+class DataStreamIntersect(
+    cluster: RelOptCluster,
+    traitSet: RelTraitSet,
+    leftNode: RelNode,
+    rightNode: RelNode,
+    rowRelDataType: RelDataType,
+    all: Boolean)
+  extends BiRel(cluster, traitSet, leftNode, rightNode)
+  with DataStreamRel {
+
+  private lazy val intersectType = if (all) {
+    "IntersectAll"
+  } else {
+    "Intersect"
+  }
+
+  override def needsUpdatesAsRetraction: Boolean = true
+
+  override def deriveRowType() = rowRelDataType
+
+  override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = {
+    new DataStreamIntersect(
+      cluster,
+      traitSet,
+      inputs.get(0),
+      inputs.get(1),
+      getRowType,
+      all
+    )
+  }
+
+  override def toString: String = {
+    s"$intersectType($intersectSelectionToString)"
+  }
+
+  override def explainTerms(pw: RelWriter): RelWriter = {
+    super.explainTerms(pw).item(s"$intersectType", intersectSelectionToString)
+  }
+
+  override def computeSelfCost(planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = {
+    val children = this.getInputs
+    children.asScala.foldLeft(planner.getCostFactory.makeCost(0, 0, 0)) { (cost, child) =>
+      val rowCnt = metadata.getRowCount(child)
+      val rowSize = this.estimateRowSize(child.getRowType)
+      cost.plus(planner.getCostFactory.makeCost(rowCnt, rowCnt, rowCnt * rowSize))
+    }
+  }
+
+  override def translateToPlan(
+      tableEnv: StreamTableEnvironment,
+      queryConfig: StreamQueryConfig): DataStream[CRow] = {
+
+    val leftDataStream = left.asInstanceOf[DataStreamRel].translateToPlan(tableEnv, queryConfig)
+    val rightDataStream = right.asInstanceOf[DataStreamRel].translateToPlan(tableEnv, queryConfig)
+
+    // key by all fields
+    val keys = leftNode.getRowType.getFieldList.asScala.indices.toArray
+
+    val rowSchema = new RowSchema(rowRelDataType).projectedTypeInfo(keys)
+
+    val coFunc = new NonWindowIntersect(
+      rowSchema.asInstanceOf[RowTypeInfo],
+      queryConfig,
+      this.all
+    )
+
+    val opName = this.toString
+
+    leftDataStream
+      .connect(rightDataStream)
+      .keyBy(
+        new CRowKeySelector(keys, rowSchema),
+        new CRowKeySelector(keys, rowSchema))
+      .process(coFunc)
+      .name(opName)
+      .returns(new CRowTypeInfo(rowSchema.asInstanceOf[RowTypeInfo]))
+  }
+
+  private def intersectSelectionToString: String = {
+    getRowType.getFieldNames.asScala.mkString(", ")
+  }
+
+}
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala
index 52dab8b33792c..d33fbd8e3cc05 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala
@@ -211,6 +211,7 @@ object FlinkRuleSets {
     DataStreamCorrelateRule.INSTANCE,
     DataStreamWindowJoinRule.INSTANCE,
     DataStreamJoinRule.INSTANCE,
+    DataStreamIntersectRule.INSTANCE,
     StreamTableSourceScanRule.INSTANCE
   )
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamIntersectRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamIntersectRule.scala
new file mode 100644
index 0000000000000..224c9f41f96bf
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamIntersectRule.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.rules.datastream
+
+import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelTraitSet}
+import org.apache.calcite.rel.RelNode
+import org.apache.calcite.rel.convert.ConverterRule
+import org.apache.flink.table.calcite.FlinkTypeFactory
+import org.apache.flink.table.plan.nodes.FlinkConventions
+import org.apache.flink.table.plan.nodes.datastream.DataStreamIntersect
+import org.apache.flink.table.plan.nodes.logical.FlinkLogicalIntersect
+
+import scala.collection.JavaConverters._
+
+class DataStreamIntersectRule
+  extends ConverterRule(
+    classOf[FlinkLogicalIntersect],
+    FlinkConventions.LOGICAL,
+    FlinkConventions.DATASTREAM,
+    "DataStreamIntersectRule") {
+
+  override def convert(rel: RelNode): RelNode = {
+    val intersect: FlinkLogicalIntersect = rel.asInstanceOf[FlinkLogicalIntersect]
+    val traitSet: RelTraitSet = rel.getTraitSet.replace(FlinkConventions.DATASTREAM)
+    val convLeft: RelNode = RelOptRule.convert(intersect.getInput(0), FlinkConventions.DATASTREAM)
+    val convRight: RelNode = RelOptRule.convert(intersect.getInput(1), FlinkConventions.DATASTREAM)
+
+    new DataStreamIntersect(
+      rel.getCluster,
+      traitSet,
+      convLeft,
+      convRight,
+      rel.getRowType,
+      intersect.all)
+  }
+
+  override def matches(call: RelOptRuleCall): Boolean = {
+    val intersect = call.rel[FlinkLogicalIntersect](0)
+    // Check that no event-time attributes are in the input
+    // because non-window intersect is unbounded
+    // and we don't know how much to hold back watermarks.
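+    // If this rule does not match, no DataStream rule can translate the intersect,
+    // so the planner fails with a TableException at translation time
+    // (see SetOperatorsValidationTest#testIntersectWithRowtimeField below).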
+    val hasRowtimeFields = intersect.getRowType.getFieldList.asScala
+      .exists(f => FlinkTypeFactory.isRowtimeIndicatorType(f.getType))
+    !hasRowtimeFields
+  }
+}
+
+object DataStreamIntersectRule {
+  val INSTANCE: RelOptRule = new DataStreamIntersectRule
+}
+
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/setop/NonWindowIntersect.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/setop/NonWindowIntersect.scala
new file mode 100644
index 0000000000000..0df447c6871df
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/setop/NonWindowIntersect.scala
@@ -0,0 +1,242 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.runtime.setop
+
+import org.apache.flink.api.common.state.{ValueState, ValueStateDescriptor}
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
+import org.apache.flink.api.java.typeutils.TupleTypeInfo
+import org.apache.flink.configuration.Configuration
+import org.apache.flink.streaming.api.functions.co.CoProcessFunction
+import org.apache.flink.table.api.{StreamQueryConfig, Types}
+import org.apache.flink.table.runtime.join.CRowWrappingMultiOutputCollector
+import org.apache.flink.table.runtime.types.CRow
+import org.apache.flink.table.typeutils.TypeCheckUtils.validateEqualsHashCode
+import org.apache.flink.table.util.Logging
+import org.apache.flink.types.Row
+import org.apache.flink.util.Collector
+
+/**
+ * This function connects the left and right input streams; it is only used for
+ * non-window stream intersect.
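+ *
+ * <p>Both input streams are keyed on all of their fields. For each key the function keeps,
+ * per side, an occurrence count together with an expiration timestamp. A row is emitted
+ * (or retracted) when its count on the arriving side moves into (or out of) the count
+ * range covered by the other side.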
+ * @param resultType  the output row type of the intersect
+ * @param queryConfig the query configuration
+ * @param all         true for INTERSECT ALL, false for INTERSECT (distinct)
+ */
+class NonWindowIntersect(
+    resultType: TypeInformation[Row],
+    queryConfig: StreamQueryConfig,
+    all: Boolean)
+  extends CoProcessFunction[CRow, CRow, CRow]
+  with Logging {
+
+  validateEqualsHashCode("intersect", resultType)
+
+  // state that counts occurrences of each left-side row (count, expiration time)
+  private var leftState: ValueState[JTuple2[Int, Long]] = _
+  // state that counts occurrences of each right-side row (count, expiration time)
+  private var rightState: ValueState[JTuple2[Int, Long]] = _
+
+  private val minRetentionTime: Long = queryConfig.getMinIdleStateRetentionTime
+  private val maxRetentionTime: Long = queryConfig.getMaxIdleStateRetentionTime
+  private val stateCleaningEnabled: Boolean = minRetentionTime > 1
+
+  // state that records the last timer of the left stream; 0 means no timer is set
+  private var leftTimer: ValueState[Long] = _
+  // state that records the last timer of the right stream; 0 means no timer is set
+  private var rightTimer: ValueState[Long] = _
+
+  private var cRowWrapper: CRowWrappingMultiOutputCollector = _
+
+  override def open(parameters: Configuration): Unit = {
+    LOG.debug("Instantiating NonWindowIntersect.")
+    // initialize the left and right state; the first element of the tuple holds the
+    // number of occurrences of a row, the second one its expiration time
+    val tupleTypeInfo = new TupleTypeInfo[JTuple2[Int, Long]](Types.INT, Types.LONG)
+    val leftStateDescriptor = new ValueStateDescriptor[JTuple2[Int, Long]](
+      "left", tupleTypeInfo)
+    val rightStateDescriptor = new ValueStateDescriptor[JTuple2[Int, Long]](
+      "right", tupleTypeInfo)
+    leftState = getRuntimeContext.getState(leftStateDescriptor)
+    rightState = getRuntimeContext.getState(rightStateDescriptor)
+
+    // initialize timer state
+    val valueStateDescriptor1 = new ValueStateDescriptor[Long]("leftTimer", classOf[Long])
+    leftTimer = getRuntimeContext.getState(valueStateDescriptor1)
+    val valueStateDescriptor2 = new ValueStateDescriptor[Long]("rightTimer", classOf[Long])
+    rightTimer = getRuntimeContext.getState(valueStateDescriptor2)
+
+    cRowWrapper = new CRowWrappingMultiOutputCollector()
+    // we emit at most one record per processed element
+    cRowWrapper.setTimes(1)
+  }
+
+  override def processElement1(
+      value: CRow,
+      ctx: CoProcessFunction[CRow, CRow, CRow]#Context,
+      out: Collector[CRow]): Unit = {
+
+    processElement(value, ctx, out, leftState, leftTimer, rightState)
+  }
+
+  override def processElement2(
+      value: CRow,
+      ctx: CoProcessFunction[CRow, CRow, CRow]#Context,
+      out: Collector[CRow]): Unit = {
+
+    processElement(value, ctx, out, rightState, rightTimer, leftState)
+  }
+
+  private def processElement(
+      value: CRow,
+      ctx: CoProcessFunction[CRow, CRow, CRow]#Context,
+      out: Collector[CRow],
+      currentSideState: ValueState[JTuple2[Int, Long]],
+      currentSideTimer: ValueState[Long],
+      otherSideState: ValueState[JTuple2[Int, Long]]): Unit = {
+
+    val inputRow = value.row
+    cRowWrapper.setChange(value.change)
+    cRowWrapper.setCollector(out)
+
+    val cntAndExpiredTime = updateState(value, ctx, currentSideState, currentSideTimer)
+
+    val otherSideValue = otherSideState.value()
+    if (otherSideValue != null) {
+      if (all) {
+        if (value.change && cntAndExpiredTime.f0 <= otherSideValue.f0) {
+          cRowWrapper.collect(inputRow)
+        } else if (!value.change && cntAndExpiredTime.f0 < otherSideValue.f0) {
+          cRowWrapper.collect(inputRow)
+        }
+      } else {
+        if (value.change && cntAndExpiredTime.f0 == 1) {
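+          // distinct INTERSECT: emit the row exactly once, namely when its count on
+          // this side first reaches 1 and the other side already contains the row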
+          cRowWrapper.collect(inputRow)
+        } else if (!value.change && cntAndExpiredTime.f0 == 0) {
+          cRowWrapper.collect(inputRow)
+        }
+      }
+    }
+  }
+
+  /**
+   * Updates the count/expiration state and the cleanup timer state of one side
+   * and returns the updated state.
+   *
+   * @param value      the incoming accumulate or retract message
+   * @param ctx        the context of the co-process function
+   * @param state      the count and expiration time state of the incoming side
+   * @param timerState the cleanup timer state of the incoming side
+   * @return the updated (count, expiration time) tuple of the incoming row
+   */
+  private def updateState(
+      value: CRow,
+      ctx: CoProcessFunction[CRow, CRow, CRow]#Context,
+      state: ValueState[JTuple2[Int, Long]],
+      timerState: ValueState[Long]): JTuple2[Int, Long] = {
+
+    val curProcessTime = ctx.timerService.currentProcessingTime
+    val oldCntAndExpiredTime = state.value()
+    val cntAndExpiredTime = if (null == oldCntAndExpiredTime) {
+      JTuple2.of(0, -1L)
+    } else {
+      oldCntAndExpiredTime
+    }
+
+    cntAndExpiredTime.f1 = getNewExpiredTime(curProcessTime, cntAndExpiredTime.f1)
+    if (stateCleaningEnabled && timerState.value() == 0) {
+      timerState.update(cntAndExpiredTime.f1)
+      ctx.timerService().registerProcessingTimeTimer(cntAndExpiredTime.f1)
+    }
+
+    if (!value.change) {
+      cntAndExpiredTime.f0 = cntAndExpiredTime.f0 - 1
+      if (cntAndExpiredTime.f0 <= 0) {
+        state.clear()
+      } else {
+        state.update(cntAndExpiredTime)
+      }
+    } else {
+      cntAndExpiredTime.f0 = cntAndExpiredTime.f0 + 1
+      state.update(cntAndExpiredTime)
+    }
+    cntAndExpiredTime
+  }
+
+  def getNewExpiredTime(
+      curProcessTime: Long,
+      oldExpiredTime: Long): Long = {
+    if (stateCleaningEnabled && curProcessTime + minRetentionTime > oldExpiredTime) {
+      curProcessTime + maxRetentionTime
+    } else {
+      oldExpiredTime
+    }
+  }
+
+  override def onTimer(
+      timestamp: Long,
+      ctx: CoProcessFunction[CRow, CRow, CRow]#OnTimerContext,
+      out: Collector[CRow]): Unit = {
+
+    if (stateCleaningEnabled && leftTimer.value == timestamp) {
+      expireOutTimeRow(
+        timestamp,
+        leftState,
+        leftTimer,
+        ctx
+      )
+    }
+
+    if (stateCleaningEnabled && rightTimer.value == timestamp) {
+      expireOutTimeRow(
+        timestamp,
+        rightState,
+        rightTimer,
+        ctx
+      )
+    }
+  }
+
+  private def expireOutTimeRow(
+      curTime: Long,
+      rowState: ValueState[JTuple2[Int, Long]],
+      timerState: ValueState[Long],
+      ctx: CoProcessFunction[CRow, CRow, CRow]#OnTimerContext): Unit = {
+
+    var validTimestamp: Boolean = false
+    val rowValue = rowState.value()
+    if (rowValue != null) {
+      val recordExpiredTime = rowValue.f1
+      if (recordExpiredTime <= curTime) {
+        rowState.clear()
+      } else {
+        // we found a timestamp that is still valid
+        validTimestamp = true
+      }
+    }
+
+    // If the state has non-expired timestamps, register a new timer.
+    // Otherwise clean the complete state for this input.
+ if (validTimestamp) { + val cleanupTime = curTime + maxRetentionTime + ctx.timerService.registerProcessingTimeTimer(cleanupTime) + timerState.update(cleanupTime) + } else { + timerState.clear() + rowState.clear() + } + } +} diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/SetOperatorsTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/SetOperatorsTest.scala index e84c630f59df8..e923a5836014f 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/SetOperatorsTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/SetOperatorsTest.scala @@ -228,4 +228,85 @@ class SetOperatorsTest extends TableTestBase { streamUtil.verifyTable(result, expected) } + + @Test + def testFilterIntersectTranspose(): Unit = { + val util = streamTestUtil() + val left = util.addTable[(Int, Long, String)]("left", 'a, 'b, 'c) + val right = util.addTable[(Int, Long, String)]("right", 'a, 'b, 'c) + + val result = left.intersect(right) + .where('a > 0) + .groupBy('b) + .select('a.sum as 'a, 'b as 'b, 'c.count as 'c) + + val expected = unaryNode( + "DataStreamCalc", + unaryNode( + "DataStreamGroupAggregate", + binaryNode( + "DataStreamIntersect", + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "a", "b", "c"), + term("where", ">(a, 0)") + ), + unaryNode( + "DataStreamCalc", + streamTableNode(1), + term("select", "a", "b", "c"), + term("where", ">(a, 0)") + ), + term("Intersect", "a", "b", "c") + ), + term("groupBy", "b"), + term("select", "b", "SUM(a) AS TMP_0", "COUNT(c) AS TMP_1") + ), + term("select", "TMP_0 AS a", "b", "TMP_1 AS c") + ) + + util.verifyTable(result, expected) + } + + @Test + def testFilterIntersectAllTranspose(): Unit = { + val util = streamTestUtil() + val left = util.addTable[(Int, Long, String)]("left", 'a, 'b, 'c) + val right = util.addTable[(Int, Long, String)]("right", 'a, 'b, 'c) + + val result = left.intersectAll(right) + .where('a > 0) + .groupBy('b) + .select('a.sum as 'a, 'b as 'b, 'c.count as 'c) + + val expected = unaryNode( + "DataStreamCalc", + unaryNode( + "DataStreamGroupAggregate", + binaryNode( + "DataStreamIntersect", + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "a", "b", "c"), + term("where", ">(a, 0)") + ), + unaryNode( + "DataStreamCalc", + streamTableNode(1), + term("select", "a", "b", "c"), + term("where", ">(a, 0)") + ), + term("IntersectAll", "a", "b", "c") + ), + term("groupBy", "b"), + term("select", "b", "SUM(a) AS TMP_0", "COUNT(c) AS TMP_1") + ), + term("select", "TMP_0 AS a", "b", "TMP_1 AS c") + ) + + util.verifyTable(result, expected) + } + } diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/stringexpr/SetOperatorsStringExpressionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/stringexpr/SetOperatorsStringExpressionTest.scala index bfb3ef63355a1..9ad1878d852cb 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/stringexpr/SetOperatorsStringExpressionTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/stringexpr/SetOperatorsStringExpressionTest.scala @@ -46,4 +46,26 @@ class SetOperatorsStringExpressionTest extends TableTestBase { val resJava = t1.unionAll(t2.select("int, long, string")).filter("int < 2").select("int") verifyTableEquals(resJava, resScala) } + + 
@Test + def testIntersect(): Unit = { + val util = streamTestUtil() + val t1 = util.addTable[(Int, Long, String)]('int, 'long, 'string) + val t2 = util.addTable[(Int, Long, String)]('int, 'long, 'string) + + val resScala = t1.intersect(t2).select('int) + val resJava = t1.intersect(t2).select("int") + verifyTableEquals(resJava, resScala) + } + + @Test + def testIntersectAll(): Unit = { + val util = streamTestUtil() + val t1 = util.addTable[(Int, Long, String)]('int, 'long, 'string) + val t2 = util.addTable[(Int, Long, String)]('int, 'long, 'string) + + val resScala = t1.intersectAll(t2).select('int) + val resJava = t1.intersectAll(t2).select("int") + verifyTableEquals(resJava, resScala) + } } diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/validation/SetOperatorsValidationTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/validation/SetOperatorsValidationTest.scala index 1b20d99e4afa6..a2dd93c44c886 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/validation/SetOperatorsValidationTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/validation/SetOperatorsValidationTest.scala @@ -19,9 +19,12 @@ package org.apache.flink.table.api.stream.table.validation import org.apache.flink.api.scala._ +import org.apache.flink.streaming.api.TimeCharacteristic +import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment +import org.apache.flink.streaming.api.watermark.Watermark import org.apache.flink.table.api.scala._ -import org.apache.flink.table.api.{TableEnvironment, ValidationException} +import org.apache.flink.table.api.{TableEnvironment, TableException, ValidationException} import org.apache.flink.table.runtime.utils.{StreamITCase, StreamTestData} import org.apache.flink.table.utils.TableTestBase import org.apache.flink.types.Row @@ -81,4 +84,44 @@ class SetOperatorsValidationTest extends TableTestBase { // Must fail. Tables are bound to different TableEnvironments. 
     ds1.unionAll(ds2)
   }
+
+  @Test(expected = classOf[TableException])
+  def testIntersectWithRowtimeField(): Unit = {
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+
+    env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)
+
+    StreamITCase.testResults = mutable.MutableList()
+
+    val watermarkExtractor = new AssignerWithPunctuatedWatermarks[(Int, Long, String)]() {
+
+      override def checkAndGetNextWatermark(
+          lastElement: (Int, Long, String),
+          extractedTimestamp: Long): Watermark = {
+        new Watermark(extractedTimestamp - 1)
+      }
+
+      override def extractTimestamp(
+          element: (Int, Long, String),
+          previousElementTimestamp: Long): Long = {
+        element._2
+      }
+    }
+
+    val ds1 = StreamTestData
+      .getSmall3TupleDataStream(env)
+      .assignTimestampsAndWatermarks(watermarkExtractor)
+      .toTable(tEnv, 'a, 'b.rowtime, 'c)
+
+    val ds2 = StreamTestData
+      .getSmall3TupleDataStream(env)
+      .assignTimestampsAndWatermarks(watermarkExtractor)
+      .toTable(tEnv, 'x, 'y.rowtime, 'z)
+
+    val result = ds1.intersect(ds2)
+    val results = result.toAppendStream[Row]
+    results.addSink(new StreamITCase.StringSink[Row])
+    env.execute()
+  }
 }
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/validation/UnsupportedOpsValidationTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/validation/UnsupportedOpsValidationTest.scala
index 91bc780507627..eea33cba9a03f 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/validation/UnsupportedOpsValidationTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/validation/UnsupportedOpsValidationTest.scala
@@ -52,24 +52,6 @@ class UnsupportedOpsValidationTest extends AbstractTestBase {
     t1.union(t2)
   }
 
-  @Test(expected = classOf[ValidationException])
-  def testIntersect(): Unit = {
-    val env = StreamExecutionEnvironment.getExecutionEnvironment
-    val tEnv = TableEnvironment.getTableEnvironment(env)
-    val t1 = StreamTestData.getSmall3TupleDataStream(env).toTable(tEnv)
-    val t2 = StreamTestData.getSmall3TupleDataStream(env).toTable(tEnv)
-    t1.intersect(t2)
-  }
-
-  @Test(expected = classOf[ValidationException])
-  def testIntersectAll(): Unit = {
-    val env = StreamExecutionEnvironment.getExecutionEnvironment
-    val tEnv = TableEnvironment.getTableEnvironment(env)
-    val t1 = StreamTestData.getSmall3TupleDataStream(env).toTable(tEnv)
-    val t2 = StreamTestData.getSmall3TupleDataStream(env).toTable(tEnv)
-    t1.intersectAll(t2)
-  }
-
   @Test(expected = classOf[ValidationException])
   def testMinus(): Unit = {
     val env = StreamExecutionEnvironment.getExecutionEnvironment
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/sql/SetOperatorsITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/sql/SetOperatorsITCase.scala
index d965e0cc09a14..ff57a5cc133a9 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/sql/SetOperatorsITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/sql/SetOperatorsITCase.scala
@@ -216,8 +216,6 @@ class SetOperatorsITCase(
   }
 
   @Test
-  @Ignore
-  // calcite sql parser doesn't support INTERSECT ALL
   def testIntersectAll(): Unit = {
     val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
     val tEnv = TableEnvironment.getTableEnvironment(env, config)
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/NonWindowHarnessTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/NonWindowHarnessTest.scala
index 5c31cb246306a..b9c2aeca5f600 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/NonWindowHarnessTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/NonWindowHarnessTest.scala
@@ -22,15 +22,21 @@ import java.util.concurrent.ConcurrentLinkedQueue
 
 import org.apache.flink.api.common.time.Time
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo
+import org.apache.flink.api.java.typeutils.RowTypeInfo
+import org.apache.flink.api.scala.typeutils.Types
 import org.apache.flink.streaming.api.operators.LegacyKeyedProcessOperator
+import org.apache.flink.streaming.api.operators.co.KeyedCoProcessOperator
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord
+import org.apache.flink.streaming.util.KeyedTwoInputStreamOperatorTestHarness
 import org.apache.flink.table.api.StreamQueryConfig
 import org.apache.flink.table.runtime.aggregate._
 import org.apache.flink.table.runtime.harness.HarnessTestBase._
+import org.apache.flink.table.runtime.setop.NonWindowIntersect
 import org.apache.flink.table.runtime.types.CRow
 import org.apache.flink.types.Row
 import org.junit.Test
 
+
 class NonWindowHarnessTest extends HarnessTestBase {
 
   protected var queryConfig =
@@ -154,4 +160,65 @@ class NonWindowHarnessTest extends HarnessTestBase {
 
     testHarness.close()
   }
+
+  @Test
+  def testNonWindowIntersect(): Unit = {
+
+    val resultType = new RowTypeInfo(
+      Types.STRING
+    )
+
+    val nonWindowIntersectFunc = new NonWindowIntersect(
+      resultType,
+      queryConfig,
+      all = false
+    )
+
+    val operator = new KeyedCoProcessOperator[String, CRow, CRow, CRow](nonWindowIntersectFunc)
+
+    val testHarness = new KeyedTwoInputStreamOperatorTestHarness[String, CRow, CRow, CRow](
+      operator,
+      new TupleRowKeySelector[String](0),
+      new TupleRowKeySelector[String](0),
+      Types.STRING,
+      1, 1, 0
+    )
+
+    val expectedOutput = new ConcurrentLinkedQueue[Object]()
+    testHarness.open()
+
+    // register cleanup timer with 3001
+    testHarness.setProcessingTime(1L)
+    testHarness.processElement1(new StreamRecord[CRow](CRow(Row.of("aaa"), change = true)))
+    testHarness.processElement1(new StreamRecord[CRow](CRow(Row.of("bbb"), change = true)))
+    testHarness.setProcessingTime(100L)
+    testHarness.processElement1(new StreamRecord[CRow](CRow(Row.of("ccc"), change = true)))
+
+    testHarness.processElement2(new StreamRecord[CRow](CRow(Row.of("aaa"), change = true)))
+
+    // left 'aaa' intersects with right 'aaa'
+    expectedOutput.add(new StreamRecord[CRow](CRow(Row.of("aaa"), change = true)))
+
+    // register cleanup timer with 4000
+    testHarness.setProcessingTime(1000L)
+
+    testHarness.processElement2(new StreamRecord[CRow](CRow(Row.of("aaa"), change = false)))
+    expectedOutput.add(new StreamRecord[CRow](CRow(Row.of("aaa"), change = false)))
+
+    // expire left records 'aaa' and 'bbb'
+    testHarness.setProcessingTime(3002L)
+
+    // this record intersects with nothing
+    testHarness.processElement2(new StreamRecord[CRow](CRow(Row.of("aaa"), change = true)))
+    // this record intersects with nothing
+    testHarness.processElement2(new StreamRecord[CRow](CRow(Row.of("bbb"), change = true)))
+
+    // right 'ccc' intersects with left 'ccc'
+    testHarness.processElement2(new StreamRecord[CRow](CRow(Row.of("ccc"), change = true)))
+    expectedOutput.add(new StreamRecord[CRow](CRow(Row.of("ccc"), change = true)))
+
+    verify(expectedOutput, testHarness.getOutput, new RowResultSortComparator())
+
+    testHarness.close()
+  }
 }
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SetOperatorsITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SetOperatorsITCase.scala
index 718fd3a0d66a5..8dea69769d3bf 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SetOperatorsITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SetOperatorsITCase.scala
@@ -18,15 +18,18 @@
 
 package org.apache.flink.table.runtime.stream.sql
 
-import org.apache.flink.api.scala._
 import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
 import org.apache.flink.table.api.TableEnvironment
 import org.apache.flink.table.api.scala._
-import org.apache.flink.table.runtime.utils.{StreamITCase, StreamingWithStateTestBase}
+import org.apache.flink.api.scala._
+import org.apache.flink.table.runtime.utils.{StreamITCase, StreamTestData, StreamingWithStateTestBase}
 import org.apache.flink.types.Row
 import org.junit.Assert.assertEquals
 import org.junit.Test
 
+import scala.collection.mutable
+import scala.util.Random
+
 class SetOperatorsITCase extends StreamingWithStateTestBase {
 
   @Test
@@ -167,4 +170,124 @@ class SetOperatorsITCase extends StreamingWithStateTestBase {
 
     assertEquals(expected.sorted, StreamITCase.retractedResults.sorted)
   }
+
+  @Test
+  def testIntersect(): Unit = {
+    val env: StreamExecutionEnvironment = StreamExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+
+    StreamITCase.clear
+    val sqlQuery = "SELECT c FROM t1 INTERSECT SELECT c FROM t2"
+
+    val ds1 = StreamTestData.getSmall3TupleDataStream(env)
+    val data = new mutable.MutableList[(Int, Long, String)]
+    data.+=((1, 1L, "Hi"))
+    data.+=((2, 2L, "Hello"))
+    data.+=((2, 2L, "Hello"))
+    data.+=((3, 2L, "Hello world!"))
+    val ds2 = env.fromCollection(Random.shuffle(data))
+
+    tEnv.registerTable("t1", ds1.toTable(tEnv, 'a, 'b, 'c))
+    tEnv.registerTable("t2", ds2.toTable(tEnv, 'a, 'b, 'c))
+
+    val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row]
+    result.addSink(new StreamITCase.StringSink[Row])
+    env.execute()
+
+    val expected = List("Hi", "Hello")
+    assertEquals(expected.sorted, StreamITCase.testResults.sorted)
+  }
+
+  @Test
+  def testIntersectAll(): Unit = {
+    val env: StreamExecutionEnvironment = StreamExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+
+    StreamITCase.clear
+    val sqlQuery = "SELECT c FROM t1 INTERSECT ALL SELECT c FROM t2"
+
+    val data1 = new mutable.MutableList[Int]
+    data1 += (1, 1, 1, 2, 2, 3, 3, 3, 4)
+    val data2 = new mutable.MutableList[Int]
+    data2 += (1, 2, 2, 3, 3, 3, 4, 4, 4, 4)
+    val ds1 = env.fromCollection(data1)
+    val ds2 = env.fromCollection(data2)
+
+    tEnv.registerTable("t1", ds1.toTable(tEnv, 'c))
+    tEnv.registerTable("t2", ds2.toTable(tEnv, 'c))
+
+    val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row]
+    result.addSink(new StreamITCase.StringSink[Row])
+    env.execute()
+
+    val expected = List("1", "2", "2", "3", "3", "3", "4")
+
+    assertEquals(expected.sorted, StreamITCase.testResults.sorted)
+  }
+
+  @Test
+  def testIntersectWithFilter(): Unit = {
+    val env: StreamExecutionEnvironment = StreamExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+
+    StreamITCase.clear
+
+    val sqlQuery = "SELECT c FROM ((SELECT * FROM t1) INTERSECT (SELECT * FROM t2)) WHERE a > 1"
+
+    val ds1 = StreamTestData.getSmall3TupleDataStream(env)
+    val ds2 = StreamTestData.get3TupleDataStream(env)
+
+    tEnv.registerTable("t1", ds1.toTable(tEnv, 'a, 'b, 'c))
+    tEnv.registerTable("t2", ds2.toTable(tEnv, 'a, 'b, 'c))
+
+    val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row]
+    result.addSink(new StreamITCase.StringSink[Row])
+    env.execute()
+
+    val expected = List("Hello", "Hello world")
+
+    assertEquals(expected.sorted, StreamITCase.testResults.sorted)
+  }
+
+  @Test
+  def testIntersectWithRetraction(): Unit = {
+    val env: StreamExecutionEnvironment = StreamExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+
+    StreamITCase.clear
+    val sqlQuery = "SELECT SUM(b) FROM t1 GROUP BY a INTERSECT SELECT SUM(b) FROM t2 GROUP BY a"
+
+    val data1 = List(
+      ("a", 1),
+      ("a", 2),
+      ("a", 3),
+      ("a", 4),
+      ("b", 10),
+      ("b", -6),
+      ("b", -1)
+    )
+    val data2 = List(
+      ("a", 3),
+      ("a", 3),
+      ("a", 3),
+      ("a", 1),
+      ("b", 1),
+      ("b", 1),
+      ("b", 1)
+    )
+
+    val ds1 = env.fromCollection(data1)
+    val ds2 = env.fromCollection(data2)
+
+    tEnv.registerTable("t1", ds1.toTable(tEnv, 'a, 'b))
+    tEnv.registerTable("t2", ds2.toTable(tEnv, 'a, 'b))
+
+    val result = tEnv.sqlQuery(sqlQuery).toRetractStream[Row]
+    result.addSink(new StreamITCase.RetractingSink)
+    env.execute()
+
+    val expected = List("10", "3")
+
+    assertEquals(expected.sorted, StreamITCase.retractedResults.sorted)
+  }
 }
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/SetOperatorsITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/SetOperatorsITCase.scala
index f01c191f4e0c1..cbbdd8f37bd4e 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/SetOperatorsITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/SetOperatorsITCase.scala
@@ -30,6 +30,7 @@ import org.junit.Assert._
 import org.junit.Test
 
 import scala.collection.mutable
+import scala.util.Random
 
 class SetOperatorsITCase extends AbstractTestBase {
 
@@ -228,4 +229,89 @@ class SetOperatorsITCase extends AbstractTestBase {
 
     assertEquals(expected.sorted, StreamITCase.retractedResults.sorted)
   }
+
+  @Test
+  def testIntersect(): Unit = {
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+
+    StreamITCase.clear
+
+    val ds1 = StreamTestData.getSmall3TupleDataStream(env).toTable(tEnv, 'a, 'b, 'c)
+    val data = new mutable.MutableList[(Int, Long, String)]
+    data.+=((1, 1L, "Hi"))
+    data.+=((2, 2L, "Hello"))
+    data.+=((2, 2L, "Hello"))
+    data.+=((3, 2L, "Hello world!"))
+    val ds2 = env.fromCollection(Random.shuffle(data)).toTable(tEnv, 'a, 'b, 'c)
+
+    val result = ds1.intersect(ds2).select('c).toAppendStream[Row]
+
+    result.addSink(new StreamITCase.StringSink[Row])
+    env.execute()
+
+    val expected = List("Hi", "Hello")
+    assertEquals(expected.sorted, StreamITCase.testResults.sorted)
+  }
+
+  @Test
+  def testIntersectAll(): Unit = {
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+
+    StreamITCase.clear
+
+    val data1 = new mutable.MutableList[Int]
+    data1 += (1, 1, 1, 2, 2, 3, 3, 3, 4)
+    val data2 = new mutable.MutableList[Int]
+    data2 += (1, 2, 2, 3, 3, 3, 4, 4, 4, 4)
+    val ds1 =
env.fromCollection(data1).toTable(tEnv, 'c) + val ds2 = env.fromCollection(data2).toTable(tEnv, 'c) + + val result = ds1.intersectAll(ds2).select('c).toAppendStream[Row] + + result.addSink(new StreamITCase.StringSink[Row]) + env.execute() + + val expected = List("1", "2", "2", "3", "3", "3", "4") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + + @Test + def testIntersectWithDifferentFieldNames(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + + StreamITCase.clear + + val ds1 = StreamTestData.getSmall3TupleDataStream(env).toTable(tEnv, 'a, 'b, 'c) + val ds2 = StreamTestData.get3TupleDataStream(env).toTable(tEnv, 'e, 'f, 'g) + + val result = ds1.intersect(ds2).select('c).toAppendStream[Row] + + result.addSink(new StreamITCase.StringSink[Row]) + env.execute() + + val expected = List("Hi", "Hello", "Hello world") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + + @Test + def testIntersectWithScalarExpression(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + + StreamITCase.clear + + val ds1 = StreamTestData.getSmall3TupleDataStream(env).toTable(tEnv, 'a, 'b, 'c) + .select('a + 1, 'b, 'c) + val ds2 = StreamTestData.get3TupleDataStream(env).toTable(tEnv, 'a, 'b, 'c) + .select('a + 1, 'b, 'c) + + val result = ds1.intersect(ds2).toAppendStream[Row] + result.addSink(new StreamITCase.StringSink[Row]) + env.execute() + + val expected = List("2,1,Hi", "3,2,Hello", "4,2,Hello world") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } }
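
For reference, a minimal end-to-end usage sketch of the streaming INTERSECT / INTERSECT ALL support added by this patch. It mirrors the ITCases above; the object name and the sample data are illustrative only, everything else uses the existing Table API / SQL surface.

import org.apache.flink.api.scala._
import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
import org.apache.flink.table.api.TableEnvironment
import org.apache.flink.table.api.scala._
import org.apache.flink.types.Row

object StreamIntersectExample {
  def main(args: Array[String]): Unit = {
    val env = StreamExecutionEnvironment.getExecutionEnvironment
    val tEnv = TableEnvironment.getTableEnvironment(env)

    // illustrative sample data; any DataStream source works
    val left = env.fromCollection(Seq((1, "Hi"), (2, "Hello"), (2, "Hello")))
    val right = env.fromCollection(Seq((2, "Hello"), (3, "Bye")))

    tEnv.registerTable("t1", left.toTable(tEnv, 'a, 'c))
    tEnv.registerTable("t2", right.toTable(tEnv, 'a, 'c))

    // SQL: both INTERSECT and INTERSECT ALL now translate to DataStreamIntersect
    val distinct = tEnv.sqlQuery("SELECT c FROM t1 INTERSECT SELECT c FROM t2")

    // Table API equivalent of INTERSECT ALL
    val all = left.toTable(tEnv, 'a, 'c).intersectAll(right.toTable(tEnv, 'a, 'c))

    // the operator may emit retractions, so convert to retract streams
    distinct.toRetractStream[Row].print()
    all.toRetractStream[Row].print()

    env.execute()
  }
}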